Commit 7f3b2043 authored by ysaito1001's avatar ysaito1001
Browse files

Make `TokenBucket` and `ClientRateLimiter` buildable

parent 436fd290
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -15,10 +15,11 @@ mod token_bucket;
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use std::fmt;

pub use client_rate_limiter::ClientRateLimiter;
pub use token_bucket::TokenBucket;
pub use client_rate_limiter::{
    ClientRateLimiter, ClientRateLimiterBuilder, ClientRateLimiterPartition,
};
pub use token_bucket::{TokenBucket, TokenBucketBuilder};

pub use client_rate_limiter::ClientRateLimiterPartition;
use std::borrow::Cow;

/// Represents the retry partition, e.g. an endpoint, a region
+75 −65
Original line number Diff line number Diff line
@@ -93,8 +93,9 @@ impl ClientRateLimiter {
            .build()
    }

    fn builder() -> Builder {
        Builder::new()
    /// Creates a new `ClientRateLimiterBuilder`
    pub fn builder() -> ClientRateLimiterBuilder {
        ClientRateLimiterBuilder::new()
    }

    pub(crate) fn acquire_permission_to_send_a_request(
@@ -237,8 +238,9 @@ fn cubic_throttle(rate_to_use: f64) -> f64 {
    rate_to_use * BETA
}

/// Builder for `ClientRateLimiter`.
#[derive(Clone, Debug, Default)]
struct Builder {
pub struct ClientRateLimiterBuilder {
    ///The rate at which token are replenished.
    token_refill_rate: Option<f64>,
    ///The maximum capacity allowed in the token bucket.
@@ -261,95 +263,112 @@ struct Builder {
    time_of_last_throttle: Option<f64>,
}

impl Builder {
    fn new() -> Self {
        Builder::default()
impl ClientRateLimiterBuilder {
    /// Create a new `ClientRateLimiterBuilder`.
    pub fn new() -> Self {
        ClientRateLimiterBuilder::default()
    }
    /// The rate at which token are replenished.
    fn set_token_refill_rate(&mut self, token_refill_rate: Option<f64>) -> &mut Self {
        self.token_refill_rate = token_refill_rate;
    pub fn token_refill_rate(mut self, token_refill_rate: f64) -> Self {
        self.set_token_refill_rate(Some(token_refill_rate));
        self
    }
    /// The rate at which token are replenished.
    fn token_refill_rate(mut self, token_refill_rate: f64) -> Self {
        self.token_refill_rate = Some(token_refill_rate);
    pub fn set_token_refill_rate(&mut self, token_refill_rate: Option<f64>) -> &mut Self {
        self.token_refill_rate = token_refill_rate;
        self
    }
    /// The maximum capacity allowed in the token bucket.
    fn set_maximum_bucket_capacity(&mut self, maximum_bucket_capacity: Option<f64>) -> &mut Self {
        self.maximum_bucket_capacity = maximum_bucket_capacity;
    pub fn maximum_bucket_capacity(mut self, maximum_bucket_capacity: f64) -> Self {
        self.set_maximum_bucket_capacity(Some(maximum_bucket_capacity));
        self
    }
    /// The maximum capacity allowed in the token bucket.
    fn maximum_bucket_capacity(mut self, maximum_bucket_capacity: f64) -> Self {
        self.maximum_bucket_capacity = Some(maximum_bucket_capacity);
    pub fn set_maximum_bucket_capacity(
        &mut self,
        maximum_bucket_capacity: Option<f64>,
    ) -> &mut Self {
        self.maximum_bucket_capacity = maximum_bucket_capacity;
        self
    }
    /// The current capacity of the token bucket. The minimum this can be is 1.0
    fn set_current_bucket_capacity(&mut self, current_bucket_capacity: Option<f64>) -> &mut Self {
        self.current_bucket_capacity = current_bucket_capacity;
    pub fn current_bucket_capacity(mut self, current_bucket_capacity: f64) -> Self {
        self.set_current_bucket_capacity(Some(current_bucket_capacity));
        self
    }
    /// The current capacity of the token bucket. The minimum this can be is 1.0
    fn current_bucket_capacity(mut self, current_bucket_capacity: f64) -> Self {
        self.current_bucket_capacity = Some(current_bucket_capacity);
    pub fn set_current_bucket_capacity(
        &mut self,
        current_bucket_capacity: Option<f64>,
    ) -> &mut Self {
        self.current_bucket_capacity = current_bucket_capacity;
        self
    }
    /// The last time the token bucket was refilled.
    fn set_time_of_last_refill(&mut self, time_of_last_refill: Option<f64>) -> &mut Self {
        self.time_of_last_refill = time_of_last_refill;
    pub fn time_of_last_refill(mut self, time_of_last_refill: f64) -> Self {
        self.set_time_of_last_refill(Some(time_of_last_refill));
        self
    }
    /// The last time the token bucket was refilled.
    fn time_of_last_refill(mut self, time_of_last_refill: f64) -> Self {
        self.time_of_last_refill = Some(time_of_last_refill);
    pub fn set_time_of_last_refill(&mut self, time_of_last_refill: Option<f64>) -> &mut Self {
        self.time_of_last_refill = time_of_last_refill;
        self
    }
    /// The smoothed rate which tokens are being retrieved.
    fn set_tokens_retrieved_per_second(
    pub fn tokens_retrieved_per_second(mut self, tokens_retrieved_per_second: f64) -> Self {
        self.set_tokens_retrieved_per_second(Some(tokens_retrieved_per_second));
        self
    }
    /// The smoothed rate which tokens are being retrieved.
    pub fn set_tokens_retrieved_per_second(
        &mut self,
        tokens_retrieved_per_second: Option<f64>,
    ) -> &mut Self {
        self.tokens_retrieved_per_second = tokens_retrieved_per_second;
        self
    }
    ///The smoothed rate which tokens are being retrieved.
    fn tokens_retrieved_per_second(mut self, tokens_retrieved_per_second: f64) -> Self {
        self.tokens_retrieved_per_second = Some(tokens_retrieved_per_second);
    /// The last half second time bucket used.
    pub fn previous_time_bucket(mut self, previous_time_bucket: f64) -> Self {
        self.set_previous_time_bucket(Some(previous_time_bucket));
        self
    }
    /// The last half second time bucket used.
    fn set_previous_time_bucket(&mut self, previous_time_bucket: Option<f64>) -> &mut Self {
    pub fn set_previous_time_bucket(&mut self, previous_time_bucket: Option<f64>) -> &mut Self {
        self.previous_time_bucket = previous_time_bucket;
        self
    }
    ///The last half second time bucket used.
    fn previous_time_bucket(mut self, previous_time_bucket: f64) -> Self {
        self.previous_time_bucket = Some(previous_time_bucket);
    /// The number of requests seen within the current time bucket.
    pub fn request_count(mut self, request_count: u64) -> Self {
        self.set_request_count(Some(request_count));
        self
    }
    /// The number of requests seen within the current time bucket.
    fn set_request_count(&mut self, request_count: Option<u64>) -> &mut Self {
    pub fn set_request_count(&mut self, request_count: Option<u64>) -> &mut Self {
        self.request_count = request_count;
        self
    }
    ///The number of requests seen within the current time bucket.
    fn request_count(mut self, request_count: u64) -> Self {
        self.request_count = Some(request_count);
    /// Boolean indicating if the token bucket is enabled. The token bucket is initially disabled. When a throttling error is encountered it is enabled.
    pub fn enable_throttling(mut self, enable_throttling: bool) -> Self {
        self.set_enable_throttling(Some(enable_throttling));
        self
    }
    /// Boolean indicating if the token bucket is enabled. The token bucket is initially disabled. When a throttling error is encountered it is enabled.
    fn set_enable_throttling(&mut self, enable_throttling: Option<bool>) -> &mut Self {
    pub fn set_enable_throttling(&mut self, enable_throttling: Option<bool>) -> &mut Self {
        self.enable_throttling = enable_throttling;
        self
    }
    ///Boolean indicating if the token bucket is enabled. The token bucket is initially disabled. When a throttling error is encountered it is enabled.
    fn enable_throttling(mut self, enable_throttling: bool) -> Self {
        self.enable_throttling = Some(enable_throttling);
    /// The maximum rate when the client was last throttled.
    pub fn tokens_retrieved_per_second_at_time_of_last_throttle(
        mut self,
        tokens_retrieved_per_second_at_time_of_last_throttle: f64,
    ) -> Self {
        self.set_tokens_retrieved_per_second_at_time_of_last_throttle(Some(
            tokens_retrieved_per_second_at_time_of_last_throttle,
        ));
        self
    }
    /// The maximum rate when the client was last throttled.
    fn set_tokens_retrieved_per_second_at_time_of_last_throttle(
    pub fn set_tokens_retrieved_per_second_at_time_of_last_throttle(
        &mut self,
        tokens_retrieved_per_second_at_time_of_last_throttle: Option<f64>,
    ) -> &mut Self {
@@ -357,30 +376,21 @@ impl Builder {
            tokens_retrieved_per_second_at_time_of_last_throttle;
        self
    }
    ///The maximum rate when the client was last throttled.
    fn tokens_retrieved_per_second_at_time_of_last_throttle(
        mut self,
        tokens_retrieved_per_second_at_time_of_last_throttle: f64,
    ) -> Self {
        self.tokens_retrieved_per_second_at_time_of_last_throttle =
            Some(tokens_retrieved_per_second_at_time_of_last_throttle);
        self
    }
    /// The last time when the client was throttled.
    fn set_time_of_last_throttle(&mut self, time_of_last_throttle: Option<f64>) -> &mut Self {
        self.time_of_last_throttle = time_of_last_throttle;
    pub fn time_of_last_throttle(mut self, time_of_last_throttle: f64) -> Self {
        self.set_time_of_last_throttle(Some(time_of_last_throttle));
        self
    }
    /// The last time when the client was throttled.
    fn time_of_last_throttle(mut self, time_of_last_throttle: f64) -> Self {
        self.time_of_last_throttle = Some(time_of_last_throttle);
    pub fn set_time_of_last_throttle(&mut self, time_of_last_throttle: Option<f64>) -> &mut Self {
        self.time_of_last_throttle = time_of_last_throttle;
        self
    }

    fn build(self) -> ClientRateLimiter {
    /// Build the ClientRateLimiter.
    pub fn build(self) -> ClientRateLimiter {
        ClientRateLimiter {
            inner: Arc::new(Mutex::new(Inner {
                fill_rate: self.token_refill_rate.unwrap_or_default(),
                fill_rate: self.token_refill_rate.unwrap_or(MIN_FILL_RATE),
                max_capacity: self.maximum_bucket_capacity.unwrap_or(f64::MAX),
                current_capacity: self.current_bucket_capacity.unwrap_or_default(),
                last_timestamp: self.time_of_last_refill,
+126 −4
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ pub struct TokenBucket {
    max_permits: usize,
    timeout_retry_cost: u32,
    retry_cost: u32,
    regeneration_amount: usize,
}

impl Storable for TokenBucket {
@@ -34,6 +35,7 @@ impl Default for TokenBucket {
            max_permits: DEFAULT_CAPACITY,
            timeout_retry_cost: RETRY_TIMEOUT_COST,
            retry_cost: RETRY_COST,
            regeneration_amount: PERMIT_REGENERATION_AMOUNT,
        }
    }
}
@@ -44,11 +46,26 @@ impl TokenBucket {
        Self {
            semaphore: Arc::new(Semaphore::new(initial_quota)),
            max_permits: initial_quota,
            retry_cost: RETRY_COST,
            timeout_retry_cost: RETRY_TIMEOUT_COST,
            ..Default::default()
        }
    }

    /// A token bucket with unlimited capacity that allows retries at no cost.
    pub fn unlimited() -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
            max_permits: Semaphore::MAX_PERMITS,
            timeout_retry_cost: 0,
            retry_cost: 0,
            regeneration_amount: 0,
        }
    }

    /// Creates a builder for constructing a `TokenBucket`.
    pub fn builder() -> TokenBucketBuilder {
        TokenBucketBuilder::default()
    }

    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
        let retry_cost = if err == &ErrorKind::TransientError {
            self.timeout_retry_cost
@@ -64,8 +81,11 @@ impl TokenBucket {

    pub(crate) fn regenerate_a_token(&self) {
        if self.semaphore.available_permits() < (self.max_permits) {
            trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
            self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
            trace!(
                "adding {regeneration_amount} back into the bucket",
                regeneration_amount = self.regeneration_amount
            );
            self.semaphore.add_permits(self.regeneration_amount)
        }
    }

@@ -74,3 +94,105 @@ impl TokenBucket {
        self.semaphore.available_permits()
    }
}

/// Builder for constructing a `TokenBucket`.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    capacity: Option<usize>,
    retry_cost: Option<u32>,
    timeout_retry_cost: Option<u32>,
    regeneration_amount: Option<usize>,
}

impl TokenBucketBuilder {
    /// Creates a new `TokenBucketBuilder` with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the specified initial capacity for the builder.
    pub fn capacity(mut self, capacity: usize) -> Self {
        self.capacity = Some(capacity);
        self
    }

    /// Sets the specified retry cost for the builder.
    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
        self.retry_cost = Some(retry_cost);
        self
    }

    /// Sets the specified timeout retry cost for the builder.
    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
        self.timeout_retry_cost = Some(timeout_retry_cost);
        self
    }

    /// Sets the specified regeneration amount for the builder.
    pub fn regeneration_amount(mut self, regeneration_amount: usize) -> Self {
        self.regeneration_amount = Some(regeneration_amount);
        self
    }

    /// Builds a `TokenBucket`.
    pub fn build(self) -> TokenBucket {
        TokenBucket {
            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
            retry_cost: self.retry_cost.unwrap_or(RETRY_COST),
            timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST),
            regeneration_amount: self
                .regeneration_amount
                .unwrap_or(PERMIT_REGENERATION_AMOUNT),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Should always acquire permits regardless of error type
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        // Should have maximum capacity
        assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS);

        // Should have zero retry costs
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // The loop count is arbitrary; should obtain permits without limit
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits

        // Verify next acquisition fails
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }
}