Unverified Commit fe0a3a3d authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Make `TokenBucket` and `ClientRateLimiter` configurable (#4263)

parents d9c5aeba 1e9af909
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
---
applies_to:
- client
- aws-sdk-rust
authors:
- ysaito1001
references:
- smithy-rs#4263
breaking: false
new_feature: false
bug_fix: false
---
Make [`TokenBucket`](https://docs.rs/aws-smithy-runtime/latest/aws_smithy_runtime/client/retries/struct.TokenBucket.html) and [`ClientRateLimiter`](https://docs.rs/aws-smithy-runtime/latest/aws_smithy_runtime/client/retries/struct.ClientRateLimiter.html) configurable through [`RetryPartition`](https://docs.rs/aws-smithy-runtime/latest/aws_smithy_runtime/client/retries/struct.RetryPartition.html).
+85 −2
Original line number Diff line number Diff line
@@ -51,7 +51,8 @@ class RetryPartitionTest {
                            .expect("success");

                        let log_contents = logs_rx.contents();
                        assert!(log_contents.contains("token bucket for RetryPartition { name: \"dontcare-us-west-2\" } added to config bag"));
                        let expected = r##"token bucket for RetryPartition { inner: Default("dontcare-us-west-2") } added to config bag"##;
                        assert!(log_contents.contains(expected));

                        """,
                        *codegenScope,
@@ -80,7 +81,8 @@ class RetryPartitionTest {
                            .expect("success");

                        let log_contents = logs_rx.contents();
                        assert!(log_contents.contains("token bucket for RetryPartition { name: \"user-partition\" } added to config bag"));
                        let expected = r##"token bucket for RetryPartition { inner: Default("user-partition") } added to config bag"##;
                        assert!(log_contents.contains(expected));

                        """,
                        *codegenScope,
@@ -90,4 +92,85 @@ class RetryPartitionTest {
            }
        }
    }

    // This test doesn't need to be in "sdk-codegen" but since "default retry partition" test was initially here,
    // it is added to this file for consistency.
    @Test
    fun `custom retry partition`() {
        awsSdkIntegrationTest(
            SdkCodegenIntegrationTest.model,
        ) { ctx, crate ->
            val codegenScope =
                arrayOf(
                    "BeforeTransmitInterceptorContextRef" to RuntimeType.beforeTransmitInterceptorContextRef(ctx.runtimeConfig),
                    "BoxError" to RuntimeType.boxError(ctx.runtimeConfig),
                    "capture_test_logs" to
                        CargoDependency.smithyRuntimeTestUtil(ctx.runtimeConfig).toType()
                            .resolve("test_util::capture_test_logs::capture_test_logs"),
                    "capture_request" to RuntimeType.captureRequest(ctx.runtimeConfig),
                    "ConfigBag" to RuntimeType.configBag(ctx.runtimeConfig),
                    "Intercept" to RuntimeType.intercept(ctx.runtimeConfig),
                    "RetryConfig" to RuntimeType.smithyTypes(ctx.runtimeConfig).resolve("retry::RetryConfig"),
                    "RetryPartition" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::RetryPartition"),
                    "RuntimeComponents" to RuntimeType.runtimeComponents(ctx.runtimeConfig),
                    "TokenBucket" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::TokenBucket"),
                )
            crate.integrationTest("custom_retry_partition") {
                tokioTest("test_custom_token_bucket") {
                    val moduleName = ctx.moduleUseName()
                    rustTemplate(
                        """
                        use std::sync::{Arc, atomic::{AtomicU32, Ordering}};
                        use $moduleName::{Client, Config};

                        ##[derive(Clone, Debug, Default)]
                        struct TestInterceptor {
                            called: Arc<AtomicU32>,
                        }
                        impl #{Intercept} for TestInterceptor {
                            fn name(&self) -> &'static str {
                                "TestInterceptor"
                            }
                            fn read_before_attempt(
                                &self,
                                _context: &#{BeforeTransmitInterceptorContextRef}<'_>,
                                _runtime_components: &#{RuntimeComponents},
                                cfg: &mut #{ConfigBag},
                            ) -> Result<(), #{BoxError}> {
                                self.called.fetch_add(1, Ordering::Relaxed);
                                let token_bucket = cfg.load::<#{TokenBucket}>().unwrap();
                                let expected = format!("permits: {}", tokio::sync::Semaphore::MAX_PERMITS);
                                assert!(
                                    format!("{token_bucket:?}").contains(&expected),
                                    "Expected debug output to contain `{expected}`, but got: {token_bucket:?}"
                                );
                                Ok(())
                            }
                        }

                        let (http_client, _) = #{capture_request}(None);
                        let test_interceptor = TestInterceptor::default();
                        let client_config = Config::builder()
                            .interceptor(test_interceptor.clone())
                            .retry_partition(#{RetryPartition}::custom("test")
                                .token_bucket(#{TokenBucket}::unlimited())
                                .build()
                            )
                            .http_client(http_client)
                            .build();

                        let client = Client::from_conf(client_config);
                        let _ = client.some_operation().send().await;

                        assert!(
                            test_interceptor.called.load(Ordering::Relaxed) == 1,
                            "the interceptor should have been called"
                        );
                        """,
                        *codegenScope,
                    )
                }
            }
        }
    }
}
+2 −0
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@ allowed_external_types = [

    "aws_smithy_runtime::client::identity::cache::IdentityCache",
    "aws_smithy_runtime::client::retries::RetryPartition",
    "aws_smithy_runtime::client::retries::client_rate_limiter::ClientRateLimiter",
    "aws_smithy_runtime::client::retries::token_bucket::TokenBucket",

    "aws_runtime::invocation_id::SharedInvocationIdGenerator",
    "aws_runtime::invocation_id::InvocationIdGenerator",
+49 −6
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.util.sdkId

class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() {
    private val runtimeConfig = codegenContext.runtimeConfig
@@ -23,6 +24,8 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf
    private val timeoutModule = RuntimeType.smithyTypes(runtimeConfig).resolve("timeout")
    private val retries = RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries")
    private val moduleUseName = codegenContext.moduleUseName()
    private val sdkId = codegenContext.serviceShape.sdkId()
    private val defaultRetryPartition = sdkId.lowercase().replace(" ", "")
    private val codegenScope =
        arrayOf(
            *preludeScope,
@@ -266,8 +269,50 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf
                    rustTemplate(
                        """
                        /// Set the partition for retry-related state. When clients share a retry partition, they will
                        /// also share things like token buckets and client rate limiters. By default, all clients
                        /// for the same service will share a partition.
                        /// also share components such as token buckets and client rate limiters.
                        /// See the [`RetryPartition`](#{RetryPartition}) documentation for more details.
                        ///
                        /// ## Default Behavior
                        ///
                        /// When no retry partition is explicitly set, the SDK automatically creates a default retry partition named `$defaultRetryPartition`
                        /// (or `$defaultRetryPartition-<region>` if a region is configured).
                        /// All $sdkId clients without an explicit retry partition will share this default partition.
                        ///
                        /// ## Notes
                        ///
                        /// - This is an advanced setting — most users won't need to modify it.
                        /// - A configured client rate limiter has no effect unless [`RetryConfig::adaptive`](#{RetryConfig}::adaptive) is used.
                        ///
                        /// ## Examples
                        ///
                        /// Creating a custom retry partition with a token bucket:
                        /// ```no_run
                        /// use $moduleUseName::config::Config;
                        /// use $moduleUseName::config::retry::{RetryPartition, TokenBucket};
                        ///
                        /// let token_bucket = TokenBucket::new(10);
                        /// let config = Config::builder()
                        ///     .retry_partition(RetryPartition::custom("custom")
                        ///         .token_bucket(token_bucket)
                        ///         .build()
                        ///     )
                        ///     .build();
                        /// ```
                        ///
                        /// Configuring a client rate limiter with adaptive retry mode:
                        /// ```no_run
                        /// use $moduleUseName::config::Config;
                        /// use $moduleUseName::config::retry::{ClientRateLimiter, RetryConfig, RetryPartition};
                        ///
                        /// let client_rate_limiter = ClientRateLimiter::new(10.0);
                        /// let config = Config::builder()
                        ///     .retry_partition(RetryPartition::custom("custom")
                        ///         .client_rate_limiter(client_rate_limiter)
                        ///         .build()
                        ///     )
                        ///     .retry_config(RetryConfig::adaptive())
                        ///     .build();
                        /// ```
                        pub fn retry_partition(mut self, retry_partition: #{RetryPartition}) -> Self {
                            self.set_retry_partition(Some(retry_partition));
                            self
@@ -278,9 +323,7 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf

                    rustTemplate(
                        """
                        /// Set the partition for retry-related state. When clients share a retry partition, they will
                        /// also share things like token buckets and client rate limiters. By default, all clients
                        /// for the same service will share a partition.
                        /// Like [`Self::retry_partition`], but takes a mutable reference to the builder and an optional `RetryPartition`
                        pub fn set_retry_partition(&mut self, retry_partition: #{Option}<#{RetryPartition}>) -> &mut Self {
                            retry_partition.map(|r| self.config.store_put(r));
                            self
@@ -327,7 +370,7 @@ class ResiliencyReExportCustomization(codegenContext: ClientCodegenContext) {
            )

            rustTemplate(
                "pub use #{types_retry}::RetryPartition;",
                "pub use #{types_retry}::{ClientRateLimiter, RetryPartition, TokenBucket};",
                "types_retry" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries"),
            )
        }
+1 −1
Original line number Diff line number Diff line
@@ -620,7 +620,7 @@ dependencies = [

[[package]]
name = "aws-smithy-runtime"
version = "1.9.0"
version = "1.9.1"
dependencies = [
 "approx",
 "aws-smithy-async",
Loading