From 312d190535b1c77625d662d18313b90af64cb448 Mon Sep 17 00:00:00 2001 From: Zelda Hessler Date: Tue, 13 Jun 2023 18:49:43 -0500 Subject: [PATCH] Update standard orchestrator retries with token bucket and more tests (#2764) ## Motivation and Context addresses #2743 ## Description - add more standard retry tests - add optional standard retries token bucket ## Testing tests are included ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: John DiSanti --- .../aws-config/src/profile/credentials.rs | 1 - .../smithy/generators/OperationGenerator.kt | 28 +- .../ServiceRuntimePluginGenerator.kt | 22 +- .../src/client/orchestrator/error.rs | 35 +- .../src/client/retries.rs | 16 +- .../src/client/retries/rate_limiting.rs | 13 - .../src/client/retries/rate_limiting/error.rs | 50 -- .../src/client/retries/rate_limiting/token.rs | 65 --- .../retries/rate_limiting/token_bucket.rs | 235 --------- .../src/client/orchestrator.rs | 35 +- .../src/client/retries/classifier.rs | 32 +- .../src/client/retries/strategy.rs | 2 +- .../src/client/retries/strategy/standard.rs | 457 +++++++++++++++++- .../src/client/runtime_plugin.rs | 2 + .../runtime_plugin/standard_token_bucket.rs | 100 ++++ rust-runtime/aws-smithy-types/src/retry.rs | 17 + 16 files changed, 655 insertions(+), 455 deletions(-) delete mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting.rs delete mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/error.rs delete mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token.rs delete mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token_bucket.rs create mode 100644 rust-runtime/aws-smithy-runtime/src/client/runtime_plugin/standard_token_bucket.rs diff --git a/aws/rust-runtime/aws-config/src/profile/credentials.rs b/aws/rust-runtime/aws-config/src/profile/credentials.rs index 774d5785a..c3d08e58d 100644 --- a/aws/rust-runtime/aws-config/src/profile/credentials.rs +++ b/aws/rust-runtime/aws-config/src/profile/credentials.rs @@ -465,7 +465,6 @@ async fn build_provider_chain( #[cfg(test)] mod test { - use crate::profile::credentials::Builder; use crate::test_case::TestEnvironment; diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt index 859126d03..5703766c5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt @@ -146,20 +146,6 @@ open class OperationGenerator( if (codegenContext.smithyRuntimeMode.generateOrchestrator) { rustTemplate( """ - pub(crate) fn register_runtime_plugins( - runtime_plugins: #{RuntimePlugins}, - handle: #{Arc}, - config_override: #{Option}, - ) -> #{RuntimePlugins} { - #{register_default_runtime_plugins}( - runtime_plugins, - #{Box}::new(Self::new()) as _, - handle, - config_override - ) - #{additional_runtime_plugins} - } - pub(crate) async fn orchestrate( runtime_plugins: &#{RuntimePlugins}, input: #{Input}, @@ -186,6 +172,20 @@ open class OperationGenerator( let input = #{TypedBox}::new(input).erase(); #{invoke_with_stop_point}(input, runtime_plugins, stop_point).await } + + pub(crate) fn register_runtime_plugins( + runtime_plugins: #{RuntimePlugins}, + handle: #{Arc}, + config_override: #{Option}, + ) -> #{RuntimePlugins} { + #{register_default_runtime_plugins}( + runtime_plugins, + #{Box}::new(Self::new()) as _, + handle, + config_override + ) + #{additional_runtime_plugins} + } """, *codegenScope, "Error" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Error"), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt index 478da7fb9..033c60b16 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -105,6 +105,8 @@ class ServiceRuntimePluginGenerator( "StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"), "default_connector" to client.resolve("conns::default_connector"), "require_connector" to client.resolve("conns::require_connector"), + "TimeoutConfig" to smithyTypes.resolve("timeout::TimeoutConfig"), + "RetryConfig" to smithyTypes.resolve("retry::RetryConfig"), ) } @@ -142,20 +144,17 @@ class ServiceRuntimePluginGenerator( self.handle.conf.endpoint_resolver()); cfg.set_endpoint_resolver(endpoint_resolver); - // TODO(enableNewSmithyRuntime): Use the `store_append` method of ConfigBag to insert classifiers - let retry_classifiers = #{RetryClassifiers}::new() - #{retry_classifier_customizations}; - cfg.set_retry_classifiers(retry_classifiers); + // TODO(enableNewSmithyRuntime): Make it possible to set retry classifiers at the service level. + // Retry classifiers can also be set at the operation level and those should be added to the + // list of classifiers defined here, rather than replacing them. let sleep_impl = self.handle.conf.sleep_impl(); - let timeout_config = self.handle.conf.timeout_config(); - let retry_config = self.handle.conf.retry_config(); + let timeout_config = self.handle.conf.timeout_config().cloned().unwrap_or_else(|| #{TimeoutConfig}::disabled()); + let retry_config = self.handle.conf.retry_config().cloned().unwrap_or_else(|| #{RetryConfig}::disabled()); - if let Some(retry_config) = retry_config { - cfg.set_retry_strategy(#{StandardRetryStrategy}::new(retry_config)); - } + cfg.set_retry_strategy(#{StandardRetryStrategy}::new(&retry_config)); - let connector_settings = timeout_config.map(#{ConnectorSettings}::from_timeout_config).unwrap_or_default(); + let connector_settings = #{ConnectorSettings}::from_timeout_config(&timeout_config); if let Some(connection) = self.handle.conf.http_connector() .and_then(|c| c.connector(&connector_settings, sleep_impl.clone())) .or_else(|| #{default_connector}(&connector_settings, sleep_impl)) { @@ -180,9 +179,6 @@ class ServiceRuntimePluginGenerator( "http_auth_scheme_customizations" to writable { writeCustomizations(customizations, ServiceRuntimePluginSection.HttpAuthScheme("cfg")) }, - "retry_classifier_customizations" to writable { - writeCustomizations(customizations, ServiceRuntimePluginSection.RetryClassifier("cfg")) - }, "additional_config" to writable { writeCustomizations(customizations, ServiceRuntimePluginSection.AdditionalConfig("cfg")) }, diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs index 152e44201..06d240f7e 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs @@ -18,6 +18,12 @@ pub enum OrchestratorError { Interceptor { err: InterceptorError }, /// An error returned by a service. Operation { err: E }, + /// An error that occurs when a request times out. + Timeout { err: BoxError }, + /// An error that occurs when request dispatch fails. + Connector { err: ConnectorError }, + /// An error that occurs when a response can't be deserialized. + Response { err: BoxError }, /// A general orchestrator error. Other { err: BoxError }, } @@ -34,11 +40,26 @@ impl OrchestratorError { Self::Operation { err } } - /// Create a new `OrchestratorError` from an [`InterceptorError`]. + /// Create a new `OrchestratorError::Interceptor` from an [`InterceptorError`]. pub fn interceptor(err: InterceptorError) -> Self { Self::Interceptor { err } } + /// Create a new `OrchestratorError::Timeout` from a [`BoxError`]. + pub fn timeout(err: BoxError) -> Self { + Self::Timeout { err } + } + + /// Create a new `OrchestratorError::Response` from a [`BoxError`]. + pub fn response(err: BoxError) -> Self { + Self::Response { err } + } + + /// Create a new `OrchestratorError::Connector` from a [`ConnectorError`]. + pub fn connector(err: ConnectorError) -> Self { + Self::Connector { err } + } + /// Convert the `OrchestratorError` into `Some` operation specific error if it is one. Otherwise, /// return `None`. pub fn as_operation_error(&self) -> Option<&E> { @@ -72,6 +93,9 @@ impl OrchestratorError { debug_assert!(phase.is_after_deserialization(), "operation errors are a result of successfully receiving and parsing a response from the server. Therefore, we must be in the 'After Deserialization' phase."); SdkError::service_error(err, response.expect("phase has a response")) } + Self::Connector { err } => SdkError::dispatch_failure(err), + Self::Timeout { err } => SdkError::timeout_error(err), + Self::Response { err } => SdkError::response_error(err, response.unwrap()), Self::Other { err } => { use Phase::*; match phase { @@ -111,15 +135,6 @@ where } } -impl From for OrchestratorError -where - E: Debug + std::error::Error + 'static, -{ - fn from(err: BoxError) -> Self { - Self::other(err) - } -} - impl From for OrchestratorError { fn from(err: TypeErasedError) -> Self { Self::operation(err) diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs index ca9699075..1a8f0f8cf 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs @@ -20,6 +20,16 @@ pub enum ShouldAttempt { YesAfterDelay(Duration), } +#[cfg(feature = "test-util")] +impl ShouldAttempt { + pub fn expect_delay(self) -> Duration { + match self { + ShouldAttempt::YesAfterDelay(delay) => delay, + _ => panic!("Expected this to be the `YesAfterDelay` variant but it was the `{self:?}` variant instead"), + } + } +} + pub trait RetryStrategy: Send + Sync + Debug { fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result; @@ -31,7 +41,7 @@ pub trait RetryStrategy: Send + Sync + Debug { } #[non_exhaustive] -#[derive(Eq, PartialEq, Debug)] +#[derive(Clone, Eq, PartialEq, Debug)] pub enum RetryReason { Error(ErrorKind), Explicit(Duration), @@ -72,10 +82,10 @@ impl RetryClassifiers { } impl ClassifyRetry for RetryClassifiers { - fn classify_retry(&self, error: &InterceptorContext) -> Option { + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { // return the first non-None result self.inner.iter().find_map(|cr| { - let maybe_reason = cr.classify_retry(error); + let maybe_reason = cr.classify_retry(ctx); match maybe_reason.as_ref() { Some(reason) => trace!( diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting.rs deleted file mode 100644 index fcb6085fb..000000000 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting.rs +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -//! Code for rate-limiting smithy clients. - -pub mod error; -pub mod token; -pub mod token_bucket; - -pub use token::Token; -pub use token_bucket::TokenBucket; diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/error.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/error.rs deleted file mode 100644 index b4f4d9821..000000000 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/error.rs +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -//! Errors related to rate limiting - -use std::fmt; - -/// Errors related to a token bucket. -#[derive(Debug)] -pub struct RateLimitingError { - kind: ErrorKind, -} - -impl RateLimitingError { - /// An error that occurs when no tokens are left in the bucket. - pub fn no_tokens() -> Self { - Self { - kind: ErrorKind::NoTokens, - } - } - - /// An error that occurs due to a bug in the code. Please report bugs you encounter. - pub fn bug(s: impl ToString) -> Self { - Self { - kind: ErrorKind::Bug(s.to_string()), - } - } -} - -#[derive(Debug)] -enum ErrorKind { - /// A token was requested but there were no tokens left in the bucket. - NoTokens, - /// This error should never occur and is a bug. Please report it. - Bug(String), -} - -impl fmt::Display for RateLimitingError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use ErrorKind::*; - match &self.kind { - NoTokens => write!(f, "No more tokens are left in the bucket."), - Bug(msg) => write!(f, "you've encountered a bug that needs reporting: {}", msg), - } - } -} - -impl std::error::Error for RateLimitingError {} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token.rs deleted file mode 100644 index 70d620e79..000000000 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token.rs +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -//! Types and traits related to token buckets. Token buckets are used to limit the amount of -//! requests a client sends in order to avoid getting throttled. Token buckets can also act as a -//! form of concurrency control if a token is required to send a new request (as opposed to retry -//! requests only). - -use tokio::sync::OwnedSemaphorePermit; - -/// A trait implemented by types that represent a token dispensed from a [`TokenBucket`](super::TokenBucket). -pub trait Token { - /// Release this token back to the bucket. This should be called if the related request succeeds. - fn release(self); - - /// Forget this token, forever banishing it to the shadow realm, from whence no tokens return. - /// This should be called if the related request fails. - fn forget(self); -} - -/// The token type of [`Standard`]. -#[derive(Debug)] -pub struct Standard { - permit: Option, -} - -impl Standard { - pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self { - Self { - permit: Some(permit), - } - } - - // Return an "empty" token for times when you need to return a token but there's no "cost" - // associated with an action. - pub(crate) fn empty() -> Self { - Self { permit: None } - } -} - -impl Token for Standard { - fn release(self) { - drop(self.permit) - } - - fn forget(self) { - if let Some(permit) = self.permit { - permit.forget() - } - } -} - -#[cfg(test)] -mod tests { - use super::Standard as Token; - use crate::client::retries::rate_limiting::token_bucket::Standard as TokenBucket; - - #[test] - fn token_bucket_trait_is_dyn_safe() { - let _tb: Box> = - Box::new(TokenBucket::builder().build()); - } -} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token_bucket.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token_bucket.rs deleted file mode 100644 index 500ee723f..000000000 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries/rate_limiting/token_bucket.rs +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -//! A token bucket intended for use with the standard smithy client retry policy. - -use super::error::RateLimitingError; -use super::token; -use super::Token; -use aws_smithy_types::retry::{ErrorKind, RetryKind}; -use std::sync::Arc; -use tokio::sync::Semaphore; -use tokio::sync::TryAcquireError; - -/// The default number of tokens to start with -const STANDARD_INITIAL_RETRY_TOKENS: usize = 500; -/// The amount of tokens to remove from the bucket when a timeout error occurs -const STANDARD_TIMEOUT_ERROR_RETRY_COST: u32 = 10; -/// The amount of tokens to remove from the bucket when a throttling error occurs -const STANDARD_RETRYABLE_ERROR_RETRY_COST: u32 = 5; - -/// This trait is implemented by types that act as token buckets. Token buckets are used to regulate -/// the amount of requests sent by clients. Different token buckets may apply different strategies -/// to manage the number of tokens in a bucket. -/// -/// related: [`Token`], [`RateLimitingError`] -pub trait TokenBucket { - /// The type of tokens this bucket dispenses. - type Token: Token; - - /// Attempt to acquire a token from the bucket. This will fail if the bucket has no more tokens. - fn try_acquire( - &self, - previous_response_kind: Option, - ) -> Result; - - /// Get the number of available tokens in the bucket. - fn available(&self) -> usize; - - /// Refill the bucket with the given number of tokens. - fn refill(&self, tokens: usize); -} - -/// A token bucket implementation that uses a `tokio::sync::Semaphore` to track the number of tokens. -/// -/// - Whenever a request succeeds on the first try, `` token(s) -/// are added back to the bucket. -/// - When a request fails with a timeout error, `` token(s) -/// are removed from the bucket. -/// - When a request fails with a retryable error, `` token(s) -/// are removed from the bucket. -/// -/// The number of tokens in the bucket will always be >= `0` and <= ``. -#[derive(Clone, Debug)] -pub struct Standard { - inner: Arc, - max_tokens: usize, - timeout_error_cost: u32, - retryable_error_cost: u32, -} - -impl Standard { - /// Create a new `TokenBucket` using builder methods. - pub fn builder() -> Builder { - Builder::default() - } -} - -/// A builder for `TokenBucket`s. -#[derive(Default, Debug)] -pub struct Builder { - starting_tokens: Option, - max_tokens: Option, - timeout_error_cost: Option, - retryable_error_cost: Option, -} - -impl Builder { - /// The number of tokens the bucket will start with. Defaults to 500. - pub fn starting_tokens(mut self, starting_tokens: usize) -> Self { - self.starting_tokens = Some(starting_tokens); - self - } - - /// The maximum number of tokens that the bucket can hold. - /// Defaults to the value of `starting_tokens`. - pub fn max_tokens(mut self, max_tokens: usize) -> Self { - self.max_tokens = Some(max_tokens); - self - } - - /// How many tokens to remove from the bucket when a request fails due to a timeout error. - /// Defaults to 10. - pub fn timeout_error_cost(mut self, timeout_error_cost: u32) -> Self { - self.timeout_error_cost = Some(timeout_error_cost); - self - } - - /// How many tokens to remove from the bucket when a request fails due to a retryable error that - /// isn't timeout-related. Defaults to 5. - pub fn retryable_error_cost(mut self, retryable_error_cost: u32) -> Self { - self.retryable_error_cost = Some(retryable_error_cost); - self - } - - /// Build this builder. Unset fields will be set to their default values. - pub fn build(self) -> Standard { - let starting_tokens = self - .starting_tokens - .unwrap_or(STANDARD_INITIAL_RETRY_TOKENS); - let max_tokens = self.max_tokens.unwrap_or(starting_tokens); - let timeout_error_cost = self - .timeout_error_cost - .unwrap_or(STANDARD_TIMEOUT_ERROR_RETRY_COST); - let retryable_error_cost = self - .retryable_error_cost - .unwrap_or(STANDARD_RETRYABLE_ERROR_RETRY_COST); - - Standard { - inner: Arc::new(Semaphore::new(starting_tokens)), - max_tokens, - timeout_error_cost, - retryable_error_cost, - } - } -} - -impl TokenBucket for Standard { - type Token = token::Standard; - - fn try_acquire( - &self, - previous_response_kind: Option, - ) -> Result { - let number_of_tokens_to_acquire = match previous_response_kind { - None => { - // Return an empty token because the quota layer lifecycle expects a for each - // request even though the standard token bucket only requires tokens for retry - // attempts. - return Ok(token::Standard::empty()); - } - - Some(retry_kind) => match retry_kind { - RetryKind::Unnecessary => { - unreachable!("BUG: asked for a token to retry a successful request") - } - RetryKind::UnretryableFailure => { - unreachable!("BUG: asked for a token to retry an un-retryable request") - } - RetryKind::Explicit(_) => self.retryable_error_cost, - RetryKind::Error(error_kind) => match error_kind { - ErrorKind::ThrottlingError | ErrorKind::TransientError => { - self.timeout_error_cost - } - ErrorKind::ServerError => self.retryable_error_cost, - ErrorKind::ClientError => unreachable!( - "BUG: asked for a token to retry a request that failed due to user error" - ), - _ => unreachable!( - "A new variant '{:?}' was added to ErrorKind, please handle it", - error_kind - ), - }, - _ => unreachable!( - "A new variant '{:?}' was added to RetryKind, please handle it", - retry_kind - ), - }, - }; - - match self - .inner - .clone() - .try_acquire_many_owned(number_of_tokens_to_acquire) - { - Ok(permit) => Ok(token::Standard::new(permit)), - Err(TryAcquireError::NoPermits) => Err(RateLimitingError::no_tokens()), - Err(other) => Err(RateLimitingError::bug(other.to_string())), - } - } - - fn available(&self) -> usize { - self.inner.available_permits() - } - - fn refill(&self, tokens: usize) { - // Ensure the bucket doesn't overflow by limiting the amount of tokens to add, if necessary. - let amount_to_add = (self.available() + tokens).min(self.max_tokens) - self.available(); - if amount_to_add > 0 { - self.inner.add_permits(amount_to_add) - } - } -} - -#[cfg(test)] -mod test { - use super::{Token, TokenBucket}; - use super::{ - STANDARD_INITIAL_RETRY_TOKENS, STANDARD_RETRYABLE_ERROR_RETRY_COST, - STANDARD_TIMEOUT_ERROR_RETRY_COST, - }; - use aws_smithy_types::retry::{ErrorKind, RetryKind}; - - #[test] - fn bucket_works() { - let bucket = super::Standard::builder().build(); - assert_eq!(bucket.available(), STANDARD_INITIAL_RETRY_TOKENS); - - let token = bucket - .try_acquire(Some(RetryKind::Error(ErrorKind::ServerError))) - .unwrap(); - assert_eq!( - bucket.available(), - STANDARD_INITIAL_RETRY_TOKENS - STANDARD_RETRYABLE_ERROR_RETRY_COST as usize - ); - Box::new(token).release(); - - let token = bucket - .try_acquire(Some(RetryKind::Error(ErrorKind::TransientError))) - .unwrap(); - assert_eq!( - bucket.available(), - STANDARD_INITIAL_RETRY_TOKENS - STANDARD_TIMEOUT_ERROR_RETRY_COST as usize - ); - Box::new(token).forget(); - assert_eq!( - bucket.available(), - STANDARD_INITIAL_RETRY_TOKENS - STANDARD_TIMEOUT_ERROR_RETRY_COST as usize - ); - - bucket.refill(STANDARD_TIMEOUT_ERROR_RETRY_COST as usize); - assert_eq!(bucket.available(), STANDARD_INITIAL_RETRY_TOKENS); - } -} diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index 8246e2e9b..9bbbf10b2 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -139,7 +139,7 @@ async fn try_op( { let request_serializer = cfg.request_serializer(); let input = ctx.take_input().expect("input set at this point"); - let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg)); + let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg).map_err(OrchestratorError::other)); ctx.set_request(request); } @@ -171,10 +171,10 @@ async fn try_op( // No, this request shouldn't be sent Ok(ShouldAttempt::No) => { let err: BoxError = "the retry strategy indicates that an initial request shouldn't be made, but it didn't specify why".into(); - halt!([ctx] => err); + halt!([ctx] => OrchestratorError::other(err)); } // No, we shouldn't make a request because... - Err(err) => halt!([ctx] => err), + Err(err) => halt!([ctx] => OrchestratorError::other(err)), Ok(ShouldAttempt::YesAfterDelay(_)) => { unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.") } @@ -183,7 +183,7 @@ async fn try_op( // Save a request checkpoint before we make the request. This will allow us to "rewind" // the request in the case of retry attempts. ctx.save_checkpoint(); - for i in 0usize.. { + for i in 1usize.. { debug!("beginning attempt #{i}"); // Break from the loop if we can't rewind the request's state. This will always succeed the // first time, but will fail on subsequent iterations if the request body wasn't retryable. @@ -201,19 +201,21 @@ async fn try_op( } .maybe_timeout_with_config(attempt_timeout_config) .await - .map_err(OrchestratorError::other); + .map_err(|err| OrchestratorError::timeout(err.into_source().unwrap())); // We continue when encountering a timeout error. The retry classifier will decide what to do with it. continue_on_err!([ctx] => maybe_timeout); let retry_strategy = cfg.retry_strategy(); + // If we got a retry strategy from the bag, ask it what to do. // If no strategy was set, we won't retry. - let should_attempt = halt_on_err!( - [ctx] => retry_strategy - .map(|rs| rs.should_attempt_retry(ctx, cfg)) - .unwrap_or(Ok(ShouldAttempt::No) - )); + let should_attempt = match retry_strategy { + Some(retry_strategy) => halt_on_err!( + [ctx] => retry_strategy.should_attempt_retry(ctx, cfg).map_err(OrchestratorError::other) + ), + None => ShouldAttempt::No, + }; match should_attempt { // Yes, let's retry the request ShouldAttempt::Yes => continue, @@ -241,11 +243,11 @@ async fn try_attempt( stop_point: StopPoint, ) { halt_on_err!([ctx] => interceptors.read_before_attempt(ctx, cfg)); - halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg)); + halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg).map_err(OrchestratorError::other)); halt_on_err!([ctx] => interceptors.modify_before_signing(ctx, cfg)); halt_on_err!([ctx] => interceptors.read_before_signing(ctx, cfg)); - halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await); + halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await.map_err(OrchestratorError::other)); halt_on_err!([ctx] => interceptors.read_after_signing(ctx, cfg)); halt_on_err!([ctx] => interceptors.modify_before_transmit(ctx, cfg)); @@ -261,7 +263,12 @@ async fn try_attempt( ctx.enter_transmit_phase(); let call_result = halt_on_err!([ctx] => { let request = ctx.take_request().expect("set during serialization"); - cfg.connection().call(request).await + cfg.connection().call(request).await.map_err(|err| { + match err.downcast() { + Ok(connector_error) => OrchestratorError::connector(*connector_error), + Err(box_err) => OrchestratorError::other(box_err) + } + }) }); ctx.set_response(call_result); ctx.enter_before_deserialization_phase(); @@ -279,7 +286,7 @@ async fn try_attempt( None => read_body(response) .instrument(debug_span!("read_body")) .await - .map_err(OrchestratorError::other) + .map_err(OrchestratorError::response) .and_then(|_| response_deserializer.deserialize_nonstreaming(response)), } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/classifier.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/classifier.rs index fd9754cff..0139e9b31 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/classifier.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/classifier.rs @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_http::result::SdkError; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; +use aws_smithy_runtime_api::client::orchestrator::OrchestratorError; use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason}; use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind}; use std::borrow::Cow; @@ -76,17 +76,17 @@ where Ok(_) => return None, Err(err) => err, }; - // Check that the error is an operation error - let error = error.as_operation_error()?; - // Downcast the error - let error = error.downcast_ref::>()?; + match error { - SdkError::TimeoutError(_) => Some(RetryReason::Error(ErrorKind::TransientError)), - SdkError::ResponseError { .. } => Some(RetryReason::Error(ErrorKind::TransientError)), - SdkError::DispatchFailure(err) if (err.is_timeout() || err.is_io()) => { + OrchestratorError::Response { .. } | OrchestratorError::Timeout { .. } => { Some(RetryReason::Error(ErrorKind::TransientError)) } - SdkError::DispatchFailure(err) => err.is_other().map(RetryReason::Error), + OrchestratorError::Connector { err } if err.is_timeout() || err.is_io() => { + Some(RetryReason::Error(ErrorKind::TransientError)) + } + OrchestratorError::Connector { err } if err.is_other().is_some() => { + err.is_other().map(RetryReason::Error) + } _ => None, } } @@ -152,8 +152,6 @@ mod test { HttpStatusCodeClassifier, ModeledAsRetryableClassifier, }; use aws_smithy_http::body::SdkBody; - use aws_smithy_http::operation; - use aws_smithy_http::result::SdkError; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::OrchestratorError; use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason}; @@ -242,11 +240,10 @@ mod test { #[test] fn classify_response_error() { let policy = SmithyErrorClassifier::::new(); - let test_response = http::Response::new("OK").map(SdkBody::from); - let err: SdkError = - SdkError::response_error(UnmodeledError, operation::Response::new(test_response)); let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter")); - ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err)))); + ctx.set_output_or_error(Err(OrchestratorError::response( + "I am a response error".into(), + ))); assert_eq!( policy.classify_retry(&ctx), Some(RetryReason::Error(ErrorKind::TransientError)), @@ -256,9 +253,10 @@ mod test { #[test] fn test_timeout_error() { let policy = SmithyErrorClassifier::::new(); - let err: SdkError = SdkError::timeout_error("blah"); let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter")); - ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err)))); + ctx.set_output_or_error(Err(OrchestratorError::timeout( + "I am a timeout error".into(), + ))); assert_eq!( policy.classify_retry(&ctx), Some(RetryReason::Error(ErrorKind::TransientError)), diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy.rs index 3805834b1..c046a4d9f 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy.rs @@ -6,7 +6,7 @@ #[cfg(feature = "test-util")] mod fixed_delay; mod never; -mod standard; +pub(crate) mod standard; #[cfg(feature = "test-util")] pub use fixed_delay::FixedDelayRetryStrategy; diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs index 689331316..468d0957b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs @@ -3,6 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ +use crate::client::retries::strategy::standard::ReleaseResult::{ + APermitWasReleased, NoPermitWasReleased, +}; +use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors}; use aws_smithy_runtime_api::client::request_attempts::RequestAttempts; @@ -11,16 +15,21 @@ use aws_smithy_runtime_api::client::retries::{ }; use aws_smithy_types::config_bag::ConfigBag; use aws_smithy_types::retry::RetryConfig; +use std::sync::Mutex; use std::time::Duration; +use tokio::sync::OwnedSemaphorePermit; +// The initial attempt, plus three retries. const DEFAULT_MAX_ATTEMPTS: usize = 4; #[derive(Debug)] pub struct StandardRetryStrategy { - max_attempts: usize, + // Retry settings + base: fn() -> f64, initial_backoff: Duration, + max_attempts: usize, max_backoff: Duration, - base: fn() -> f64, + retry_permit: Mutex>, } impl StandardRetryStrategy { @@ -45,6 +54,36 @@ impl StandardRetryStrategy { self.initial_backoff = initial_backoff; self } + + pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self { + self.max_backoff = max_backoff; + self + } + + fn release_retry_permit(&self) -> ReleaseResult { + let mut retry_permit = self.retry_permit.lock().unwrap(); + match retry_permit.take() { + Some(p) => { + drop(p); + APermitWasReleased + } + None => NoPermitWasReleased, + } + } + + fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) { + let mut old_retry_permit = self.retry_permit.lock().unwrap(); + if let Some(p) = old_retry_permit.replace(new_retry_permit) { + // Whenever we set a new retry permit and it replaces the old one, we need to "forget" + // the old permit, removing it from the bucket forever. + p.forget() + } + } +} + +enum ReleaseResult { + APermitWasReleased, + NoPermitWasReleased, } impl Default for StandardRetryStrategy { @@ -55,13 +94,14 @@ impl Default for StandardRetryStrategy { // by default, use a random base for exponential backoff base: fastrand::f64, initial_backoff: Duration::from_secs(1), + retry_permit: Mutex::new(None), } } } impl RetryStrategy for StandardRetryStrategy { - // TODO(token-bucket) add support for optional cross-request token bucket fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result { + // The standard token bucket is only ever considered for retry requests. Ok(ShouldAttempt::Yes) } @@ -74,18 +114,31 @@ impl RetryStrategy for StandardRetryStrategy { let output_or_error = ctx.output_or_error().expect( "This must never be called without reaching the point where the result exists.", ); + let token_bucket = cfg.get::(); if output_or_error.is_ok() { tracing::debug!("request succeeded, no retry necessary"); + if let Some(tb) = token_bucket { + // If this retry strategy is holding any permits, release them back to the bucket. + if let NoPermitWasReleased = self.release_retry_permit() { + // In the event that there was no retry permit to release, we generate new + // permits from nothing. We do this to make up for permits we had to "forget". + // Otherwise, repeated retries would empty the bucket and nothing could fill it + // back up again. + tb.regenerate_a_token(); + } + } + return Ok(ShouldAttempt::No); } // Check if we're out of attempts - let request_attempts: &RequestAttempts = cfg - .get() - .expect("at least one request attempt is made before any retry is attempted"); - if request_attempts.attempts() >= self.max_attempts { + let request_attempts = cfg + .get::() + .expect("at least one request attempt is made before any retry is attempted") + .attempts(); + if request_attempts >= self.max_attempts { tracing::trace!( - attempts = request_attempts.attempts(), + attempts = request_attempts, max_attempts = self.max_attempts, "not retrying because we are out of attempts" ); @@ -95,9 +148,24 @@ impl RetryStrategy for StandardRetryStrategy { // Run the classifiers against the context to determine if we should retry let retry_classifiers = cfg.retry_classifiers(); let retry_reason = retry_classifiers.classify_retry(ctx); + + // Calculate the appropriate backoff time. let backoff = match retry_reason { Some(RetryReason::Explicit(dur)) => dur, - Some(RetryReason::Error(_)) => { + Some(RetryReason::Error(kind)) => { + // If a token bucket was set, and the RetryReason IS NOT explicit, attempt to acquire a retry permit. + if let Some(tb) = token_bucket { + match tb.acquire(&kind) { + Some(permit) => self.set_retry_permit(permit), + None => { + tracing::debug!( + "attempt #{request_attempts} failed with {kind:?}; However, no retry permits are available, so no retry will be attempted.", + ); + return Ok(ShouldAttempt::No); + } + } + }; + let backoff = calculate_exponential_backoff( // Generate a random base multiplier to create jitter (self.base)(), @@ -105,16 +173,14 @@ impl RetryStrategy for StandardRetryStrategy { self.initial_backoff.as_secs_f64(), // `self.local.attempts` tracks number of requests made including the initial request // The initial attempt shouldn't count towards backoff calculations so we subtract it - (request_attempts.attempts() - 1) as u32, + (request_attempts - 1) as u32, ); Duration::from_secs_f64(backoff).min(self.max_backoff) } - Some(_) => { - unreachable!("RetryReason is non-exhaustive. Therefore, we need to cover this unreachable case.") - } + Some(_) => unreachable!("RetryReason is non-exhaustive"), None => { tracing::trace!( - attempts = request_attempts.attempts(), + attempts = request_attempts, max_attempts = self.max_attempts, "encountered unretryable error" ); @@ -123,8 +189,7 @@ impl RetryStrategy for StandardRetryStrategy { }; tracing::debug!( - "attempt {} failed with {:?}; retrying after {:?}", - request_attempts.attempts(), + "attempt #{request_attempts} failed with {:?}; retrying after {:?}", retry_reason.expect("the match statement above ensures this is not None"), backoff ); @@ -139,16 +204,23 @@ fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts #[cfg(test)] mod tests { - use super::{ShouldAttempt, StandardRetryStrategy}; + use super::{calculate_exponential_backoff, ShouldAttempt, StandardRetryStrategy}; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError}; use aws_smithy_runtime_api::client::request_attempts::RequestAttempts; - use aws_smithy_runtime_api::client::retries::{AlwaysRetry, RetryClassifiers, RetryStrategy}; + use aws_smithy_runtime_api::client::retries::{ + AlwaysRetry, ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy, + }; use aws_smithy_types::config_bag::{ConfigBag, Layer}; - use aws_smithy_types::retry::ErrorKind; + use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind}; use aws_smithy_types::type_erasure::TypeErasedBox; + use std::fmt; + use std::sync::Mutex; use std::time::Duration; + #[cfg(feature = "test-util")] + use crate::client::runtime_plugin::standard_token_bucket::StandardTokenBucket; + #[test] fn no_retry_necessary_for_ok_result() { let cfg = ConfigBag::base(); @@ -221,4 +293,351 @@ mod tests { .expect("method is infallible for this use"); assert_eq!(ShouldAttempt::No, actual); } + + #[derive(Debug)] + struct ServerError; + impl fmt::Display for ServerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OperationError") + } + } + + impl std::error::Error for ServerError {} + + impl ProvideErrorKind for ServerError { + fn retryable_error_kind(&self) -> Option { + Some(ErrorKind::ServerError) + } + + fn code(&self) -> Option<&str> { + None + } + } + + #[derive(Debug)] + struct PresetReasonRetryClassifier { + retry_reasons: Mutex>, + } + + #[cfg(feature = "test-util")] + impl PresetReasonRetryClassifier { + fn new(mut retry_reasons: Vec) -> Self { + // We'll pop the retry_reasons in reverse order so we reverse the list to fix that. + retry_reasons.reverse(); + Self { + retry_reasons: Mutex::new(retry_reasons), + } + } + } + + impl ClassifyRetry for PresetReasonRetryClassifier { + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { + if ctx.output_or_error().map(|it| it.is_ok()).unwrap_or(false) { + return None; + } + + let mut retry_reasons = self.retry_reasons.lock().unwrap(); + if retry_reasons.len() == 1 { + Some(retry_reasons.first().unwrap().clone()) + } else { + retry_reasons.pop() + } + } + + fn name(&self) -> &'static str { + "Always returns a preset retry reason" + } + } + + #[cfg(feature = "test-util")] + fn setup_test(retry_reasons: Vec) -> (ConfigBag, InterceptorContext) { + let mut cfg = ConfigBag::base(); + cfg.interceptor_state().set_retry_classifiers( + RetryClassifiers::new() + .with_classifier(PresetReasonRetryClassifier::new(retry_reasons)), + ); + let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter()); + // This type doesn't matter b/c the classifier will just return whatever we tell it to. + ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter"))); + + (cfg, ctx) + } + + #[cfg(feature = "test-util")] + #[test] + fn eventual_success() { + let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(5); + cfg.interceptor_state().put(StandardTokenBucket::default()); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 495); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(2)); + assert_eq!(token_bucket.available_permits(), 490); + + ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter())); + + cfg.interceptor_state().put(RequestAttempts::new(3)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + assert_eq!(token_bucket.available_permits(), 495); + } + + #[cfg(feature = "test-util")] + #[test] + fn no_more_attempts() { + let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(3); + cfg.interceptor_state().put(StandardTokenBucket::default()); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 495); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(2)); + assert_eq!(token_bucket.available_permits(), 490); + + cfg.interceptor_state().put(RequestAttempts::new(3)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + assert_eq!(token_bucket.available_permits(), 490); + } + + #[cfg(feature = "test-util")] + #[test] + fn no_quota() { + let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(5); + cfg.interceptor_state().put(StandardTokenBucket::new(5)); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 0); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + assert_eq!(token_bucket.available_permits(), 0); + } + + #[cfg(feature = "test-util")] + #[test] + fn quota_replenishes_on_success() { + let (mut cfg, mut ctx) = setup_test(vec![ + RetryReason::Error(ErrorKind::TransientError), + RetryReason::Explicit(Duration::from_secs(1)), + ]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(5); + cfg.interceptor_state().put(StandardTokenBucket::new(100)); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 90); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 90); + + ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter())); + + cfg.interceptor_state().put(RequestAttempts::new(3)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + + assert_eq!(token_bucket.available_permits(), 100); + } + + #[cfg(feature = "test-util")] + #[test] + fn quota_replenishes_on_first_try_success() { + const PERMIT_COUNT: usize = 20; + let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(usize::MAX); + cfg.interceptor_state() + .put(StandardTokenBucket::new(PERMIT_COUNT)); + let token_bucket = cfg.get::().unwrap().clone(); + + let mut attempt = 1; + + // Drain all available permits with failed attempts + while token_bucket.available_permits() > 0 { + // Draining should complete in 2 attempts + if attempt > 2 { + panic!("This test should have completed by now (drain)"); + } + + cfg.interceptor_state().put(RequestAttempts::new(attempt)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_))); + attempt += 1; + } + + // Forget the permit so that we can only refill by "success on first try". + let permit = strategy.retry_permit.lock().unwrap().take().unwrap(); + permit.forget(); + + ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter())); + + // Replenish permits until we get back to `PERMIT_COUNT` + while token_bucket.available_permits() < PERMIT_COUNT { + if attempt > 23 { + panic!("This test should have completed by now (fillup)"); + } + + cfg.interceptor_state().put(RequestAttempts::new(attempt)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + attempt += 1; + } + + assert_eq!(attempt, 23); + assert_eq!(token_bucket.available_permits(), PERMIT_COUNT); + } + + #[cfg(feature = "test-util")] + #[test] + fn backoff_timing() { + let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(5); + cfg.interceptor_state().put(StandardTokenBucket::default()); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 495); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(2)); + assert_eq!(token_bucket.available_permits(), 490); + + cfg.interceptor_state().put(RequestAttempts::new(3)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(4)); + assert_eq!(token_bucket.available_permits(), 485); + + cfg.interceptor_state().put(RequestAttempts::new(4)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(8)); + assert_eq!(token_bucket.available_permits(), 480); + + cfg.interceptor_state().put(RequestAttempts::new(5)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + assert_eq!(token_bucket.available_permits(), 480); + } + + #[cfg(feature = "test-util")] + #[test] + fn max_backoff_time() { + let (mut cfg, ctx) = setup_test(vec![RetryReason::Error(ErrorKind::ServerError)]); + let strategy = StandardRetryStrategy::default() + .with_base(|| 1.0) + .with_max_attempts(5) + .with_initial_backoff(Duration::from_secs(1)) + .with_max_backoff(Duration::from_secs(3)); + cfg.interceptor_state().put(StandardTokenBucket::default()); + let token_bucket = cfg.get::().unwrap().clone(); + + cfg.interceptor_state().put(RequestAttempts::new(1)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(1)); + assert_eq!(token_bucket.available_permits(), 495); + + cfg.interceptor_state().put(RequestAttempts::new(2)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(2)); + assert_eq!(token_bucket.available_permits(), 490); + + cfg.interceptor_state().put(RequestAttempts::new(3)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(3)); + assert_eq!(token_bucket.available_permits(), 485); + + cfg.interceptor_state().put(RequestAttempts::new(4)); + let should_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + let dur = should_retry.expect_delay(); + assert_eq!(dur, Duration::from_secs(3)); + assert_eq!(token_bucket.available_permits(), 480); + + cfg.interceptor_state().put(RequestAttempts::new(5)); + let no_retry = strategy.should_attempt_retry(&ctx, &cfg).unwrap(); + assert_eq!(no_retry, ShouldAttempt::No); + assert_eq!(token_bucket.available_permits(), 480); + } + + #[test] + fn calculate_exponential_backoff_where_initial_backoff_is_one() { + let initial_backoff = 1.0; + + for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() { + let actual_backoff = + calculate_exponential_backoff(1.0, initial_backoff, attempt as u32); + assert_eq!(expected_backoff, actual_backoff); + } + } + + #[test] + fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() { + let initial_backoff = 3.0; + + for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() { + let actual_backoff = + calculate_exponential_backoff(1.0, initial_backoff, attempt as u32); + assert_eq!(expected_backoff, actual_backoff); + } + } + + #[test] + fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() { + let initial_backoff = 0.03; + + for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() { + let actual_backoff = + calculate_exponential_backoff(1.0, initial_backoff, attempt as u32); + assert_eq!(expected_backoff, actual_backoff); + } + } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin.rs b/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin.rs index 784de5cd7..a127873bd 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin.rs @@ -5,3 +5,5 @@ #[cfg(feature = "anonymous-auth")] pub mod anonymous_auth; + +pub mod standard_token_bucket; diff --git a/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin/standard_token_bucket.rs b/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin/standard_token_bucket.rs new file mode 100644 index 000000000..548d63307 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/src/client/runtime_plugin/standard_token_bucket.rs @@ -0,0 +1,100 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; +use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, StoreReplace}; +use aws_smithy_types::retry::ErrorKind; +use std::sync::Arc; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tracing::trace; + +/// A [RuntimePlugin] to provide a standard token bucket, usable by the +/// [`StandardRetryStrategy`](crate::client::retries::strategy::standard::StandardRetryStrategy). +#[non_exhaustive] +#[derive(Debug, Default)] +pub struct StandardTokenBucketRuntimePlugin { + token_bucket: StandardTokenBucket, +} + +impl StandardTokenBucketRuntimePlugin { + pub fn new(initial_tokens: usize) -> Self { + Self { + token_bucket: StandardTokenBucket::new(initial_tokens), + } + } +} + +impl RuntimePlugin for StandardTokenBucketRuntimePlugin { + fn config(&self) -> Option { + let mut cfg = Layer::new("standard token bucket"); + cfg.store_put(self.token_bucket.clone()); + + Some(cfg.freeze()) + } +} + +const DEFAULT_CAPACITY: usize = 500; +const RETRY_COST: u32 = 5; +const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2; +const PERMIT_REGENERATION_AMOUNT: usize = 1; + +#[derive(Clone, Debug)] +pub(crate) struct StandardTokenBucket { + semaphore: Arc, + max_permits: usize, + timeout_retry_cost: u32, + retry_cost: u32, +} + +impl Storable for StandardTokenBucket { + type Storer = StoreReplace; +} + +impl Default for StandardTokenBucket { + fn default() -> Self { + Self { + semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)), + max_permits: DEFAULT_CAPACITY, + timeout_retry_cost: RETRY_TIMEOUT_COST, + retry_cost: RETRY_COST, + } + } +} + +impl StandardTokenBucket { + pub(crate) fn new(initial_quota: usize) -> Self { + Self { + semaphore: Arc::new(Semaphore::new(initial_quota)), + max_permits: initial_quota, + retry_cost: RETRY_COST, + timeout_retry_cost: RETRY_TIMEOUT_COST, + } + } + + pub(crate) fn acquire(&self, err: &ErrorKind) -> Option { + let retry_cost = if err == &ErrorKind::TransientError { + self.timeout_retry_cost + } else { + self.retry_cost + }; + + self.semaphore + .clone() + .try_acquire_many_owned(retry_cost) + .ok() + } + + pub(crate) fn regenerate_a_token(&self) { + if self.semaphore.available_permits() < (self.max_permits) { + trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket"); + self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT) + } + } + + #[cfg(all(test, feature = "test-util"))] + pub(crate) fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } +} diff --git a/rust-runtime/aws-smithy-types/src/retry.rs b/rust-runtime/aws-smithy-types/src/retry.rs index b96ababf0..9602d6135 100644 --- a/rust-runtime/aws-smithy-types/src/retry.rs +++ b/rust-runtime/aws-smithy-types/src/retry.rs @@ -143,6 +143,7 @@ pub struct RetryConfigBuilder { mode: Option, max_attempts: Option, initial_backoff: Option, + max_backoff: Option, reconnect_mode: Option, } @@ -212,6 +213,18 @@ impl RetryConfigBuilder { self } + /// Set the max_backoff duration. This duration should be non-zero. + pub fn set_max_backoff(&mut self, max_backoff: Option) -> &mut Self { + self.max_backoff = max_backoff; + self + } + + /// Set the max_backoff duration. This duration should be non-zero. + pub fn max_backoff(mut self, max_backoff: Duration) -> Self { + self.set_max_backoff(Some(max_backoff)); + self + } + /// Merge two builders together. Values from `other` will only be used as a fallback for values /// from `self` Useful for merging configs from different sources together when you want to /// handle "precedence" per value instead of at the config level @@ -233,6 +246,7 @@ impl RetryConfigBuilder { mode: self.mode.or(other.mode), max_attempts: self.max_attempts.or(other.max_attempts), initial_backoff: self.initial_backoff.or(other.initial_backoff), + max_backoff: self.max_backoff.or(other.max_backoff), reconnect_mode: self.reconnect_mode.or(other.reconnect_mode), } } @@ -248,6 +262,7 @@ impl RetryConfigBuilder { reconnect_mode: self .reconnect_mode .unwrap_or(ReconnectMode::ReconnectOnTransientError), + max_backoff: self.max_backoff.unwrap_or_else(|| Duration::from_secs(20)), } } } @@ -259,6 +274,7 @@ pub struct RetryConfig { mode: RetryMode, max_attempts: u32, initial_backoff: Duration, + max_backoff: Duration, reconnect_mode: ReconnectMode, } @@ -286,6 +302,7 @@ impl RetryConfig { max_attempts: 3, initial_backoff: Duration::from_secs(1), reconnect_mode: ReconnectMode::ReconnectOnTransientError, + max_backoff: Duration::from_secs(20), } } -- GitLab