Unverified Commit 8f844c57 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add Retry support (#232)

* Add Retry support

* Fixup some broken tests

* Refactor, add docs, rename retry strategy

* Some more renames, some more docs
parent 6eaae060
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -55,6 +55,8 @@ jobs:
        java-version: ${{ env.java_version }}
    - name: test
      run: ./gradlew :codegen:test
    - name: aws tests
      run: ./gradlew :aws:sdk-codegen:test
  integration-tests:
    name: Codegen integration tests
    runs-on: ubuntu-latest
+32 −15
Original line number Diff line number Diff line
pub mod user_agent;

use smithy_http::result::SdkError;
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::time::Duration;
@@ -11,6 +13,7 @@ use std::time::Duration;
/// 3. The code is checked against a predetermined list of throttling errors & transient error codes
/// 4. The status code is checked against a predetermined list of status codes
#[non_exhaustive]
#[derive(Clone)]
pub struct AwsErrorRetryPolicy;

const TRANSIENT_ERROR_STATUS_CODES: [u16; 2] = [400, 408];
@@ -45,11 +48,16 @@ impl Default for AwsErrorRetryPolicy {
    }
}

impl ClassifyResponse for AwsErrorRetryPolicy {
    fn classify<E, B>(&self, err: E, response: &http::Response<B>) -> RetryKind
impl<T, E, B> ClassifyResponse<T, SdkError<E, B>> for AwsErrorRetryPolicy
where
    E: ProvideErrorKind,
{
    fn classify(&self, err: Result<&T, &SdkError<E, B>>) -> RetryKind {
        let (err, response) = match err {
            Ok(_) => return RetryKind::NotRetryable,
            Err(SdkError::ServiceError { err, raw }) => (err, raw),
            Err(_) => return RetryKind::NotRetryable,
        };
        if let Some(retry_after_delay) = response
            .headers()
            .get("x-amz-retry-after")
@@ -80,6 +88,7 @@ impl ClassifyResponse for AwsErrorRetryPolicy {
#[cfg(test)]
mod test {
    use crate::AwsErrorRetryPolicy;
    use smithy_http::result::{SdkError, SdkSuccess};
    use smithy_http::retry::ClassifyResponse;
    use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
    use std::time::Duration;
@@ -110,12 +119,16 @@ mod test {
        }
    }

    fn make_err<E, B>(err: E, raw: http::Response<B>) -> Result<SdkSuccess<(), B>, SdkError<E, B>> {
        Err(SdkError::ServiceError { err, raw })
    }

    #[test]
    fn not_an_error() {
        let policy = AwsErrorRetryPolicy::new();
        let test_response = http::Response::new("OK");
        assert_eq!(
            policy.classify(UnmodeledError, &test_response),
            policy.classify(make_err(UnmodeledError, test_response).as_ref()),
            RetryKind::NotRetryable
        );
    }
@@ -128,7 +141,7 @@ mod test {
            .body("error!")
            .unwrap();
        assert_eq!(
            policy.classify(UnmodeledError, &test_resp),
            policy.classify(make_err(UnmodeledError, test_resp).as_ref()),
            RetryKind::Error(ErrorKind::TransientError)
        );
    }
@@ -139,16 +152,20 @@ mod test {
        let policy = AwsErrorRetryPolicy::new();

        assert_eq!(
            policy.classify(CodedError { code: "Throttling" }, &test_response),
            policy.classify(make_err(CodedError { code: "Throttling" }, test_response).as_ref()),
            RetryKind::Error(ErrorKind::ThrottlingError)
        );

        let test_response = http::Response::new("OK");
        assert_eq!(
            policy.classify(
                make_err(
                    CodedError {
                        code: "RequestTimeout"
                    },
                &test_response,
                    test_response
                )
                .as_ref()
            ),
            RetryKind::Error(ErrorKind::TransientError)
        )
@@ -164,7 +181,7 @@ mod test {
        let test_response = http::Response::new("OK");
        let policy = AwsErrorRetryPolicy::new();
        assert_eq!(
            policy.classify(err, &test_response),
            policy.classify(make_err(err, test_response).as_ref()),
            RetryKind::Error(ErrorKind::ThrottlingError)
        );
    }
@@ -187,7 +204,7 @@ mod test {
        let policy = AwsErrorRetryPolicy::new();

        assert_eq!(
            policy.classify(ModeledRetries, &test_response),
            policy.classify(make_err(ModeledRetries, test_response).as_ref()),
            RetryKind::Error(ErrorKind::ClientError)
        );
    }
@@ -201,7 +218,7 @@ mod test {
            .unwrap();

        assert_eq!(
            policy.classify(UnmodeledError, &test_response),
            policy.classify(make_err(UnmodeledError, test_response).as_ref()),
            RetryKind::Explicit(Duration::from_millis(5000))
        );
    }
+5 −2
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ edition = "2018"

[dependencies]
hyper = { version = "0.14.2", features = ["client", "http1", "http2", "tcp", "runtime"] }
tower = { version = "0.4.6", features = ["util"] }
tower = { version = "0.4.6", features = ["util", "retry"] }
hyper-tls = "0.5.0"
aws-auth = { path = "../aws-auth" }
aws-sig-auth = { path = "../aws-sig-auth" }
@@ -18,9 +18,12 @@ http = "0.2.3"
bytes = "1"
http-body = "0.4.0"
smithy-http = { path = "../../../rust-runtime/smithy-http" }
smithy-types = { path = "../../../rust-runtime/smithy-types" }
smithy-http-tower = { path = "../../../rust-runtime/smithy-http-tower" }
fastrand = "1.4.0"
tokio = { version = "1", features = ["time"]}

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["full", "test-util"] }
tower-test = "0.4.0"
aws-types = { path = "../aws-types" }
+27 −6
Original line number Diff line number Diff line
mod retry;
pub mod test_connection;
pub use retry::RetryConfig;

use crate::retry::RetryHandlerFactory;
use aws_endpoint::AwsEndpointStage;
use aws_http::user_agent::UserAgentStage;
use aws_sig_auth::middleware::SigV4SigningStage;
@@ -10,9 +13,11 @@ use hyper_tls::HttpsConnector;
use smithy_http::body::SdkBody;
use smithy_http::operation::Operation;
use smithy_http::response::ParseHttpResponse;
use smithy_http::retry::ClassifyResponse;
use smithy_http_tower::dispatch::DispatchLayer;
use smithy_http_tower::map_request::MapRequestLayer;
use smithy_http_tower::parse_response::ParseResponseLayer;
use smithy_types::retry::ProvideErrorKind;
use std::error::Error;
use tower::{Service, ServiceBuilder, ServiceExt};

@@ -39,12 +44,21 @@ pub type SdkSuccess<T> = smithy_http::result::SdkSuccess<T, hyper::Body>;

pub struct Client<S> {
    inner: S,
    retry_strategy: RetryHandlerFactory,
}

impl<S> Client<S> {
    /// Construct a new `Client` with a custom connector
    pub fn new(connector: S) -> Self {
        Client { inner: connector }
        Client {
            inner: connector,
            retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
        }
    }

    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_strategy.with_config(retry_config);
        self
    }
}

@@ -53,7 +67,10 @@ impl Client<hyper::Client<HttpsConnector<HttpConnector>, SdkBody>> {
    pub fn https() -> Self {
        let https = HttpsConnector::new();
        let client = HyperClient::builder().build::<_, SdkBody>(https);
        Client { inner: client }
        Client {
            inner: client,
            retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
        }
    }
}

@@ -72,8 +89,9 @@ where
    /// access the raw response use `call_raw`.
    pub async fn call<O, T, E, Retry>(&self, input: Operation<O, Retry>) -> Result<T, SdkError<E>>
    where
        O: ParseHttpResponse<hyper::Body, Output = Result<T, E>> + Send + 'static,
        E: Error,
        O: ParseHttpResponse<hyper::Body, Output = Result<T, E>> + Send + Clone + 'static,
        E: Error + ProvideErrorKind,
        Retry: ClassifyResponse<SdkSuccess<T>, SdkError<E>>,
    {
        self.call_raw(input).await.map(|res| res.parsed)
    }
@@ -87,14 +105,17 @@ where
        input: Operation<O, Retry>,
    ) -> Result<SdkSuccess<R>, SdkError<E>>
    where
        O: ParseHttpResponse<hyper::Body, Output = Result<R, E>> + Send + 'static,
        E: Error,
        O: ParseHttpResponse<hyper::Body, Output = Result<R, E>> + Send + Clone + 'static,
        E: Error + ProvideErrorKind,
        Retry: ClassifyResponse<SdkSuccess<R>, SdkError<E>>,
    {
        let signer = MapRequestLayer::for_mapper(SigV4SigningStage::new(SigV4Signer::new()));
        let endpoint_resolver = MapRequestLayer::for_mapper(AwsEndpointStage);
        let user_agent = MapRequestLayer::for_mapper(UserAgentStage::new());
        let inner = self.inner.clone();
        let mut svc = ServiceBuilder::new()
            // Create a new request-scoped policy
            .retry(self.retry_strategy.new_handler())
            .layer(ParseResponseLayer::<O, Retry>::new())
            .layer(endpoint_resolver)
            .layer(signer)
+407 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

//! Retry support for aws-hyper
//!
//! The actual retry policy implementation will likely be replaced
//! with the CRT implementation once the bindings exist. This
//! implementation is intended to be _correct_ but not especially long lasting.
//!
//! Components:
//! - [`RetryHandlerFactory`](crate::retry::RetryHandlerFactory): Top level manager, intended
//! to be associated with a [`Client`](crate::Client). Its sole purpose in life is to create a RetryHandler
//! for individual requests.
//! - [`RetryHandler`](crate::retry::RetryHandler): A request-scoped retry policy,
//! backed by request-local state and shared state contained within [`RetryHandlerFactory`](crate::retry::RetryHandlerFactory)
//! - [`RetryConfig`](crate::retry::RetryConfig): Static configuration (max retries, max backoff etc.)

use crate::{SdkError, SdkSuccess};
use smithy_http::operation;
use smithy_http::operation::Operation;
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;

/// Retry Policy Configuration
///
/// Without specific use cases, users should generally rely on the default values set by `[RetryConfig::default]`(RetryConfig::default).`
///
/// Currently these fields are private and no setters provided. As needed, this configuration will become user-modifiable in the future..
#[derive(Clone)]
pub struct RetryConfig {
    initial_retry_tokens: usize,
    retry_cost: usize,
    no_retry_increment: usize,
    timeout_retry_cost: usize,
    max_retries: u32,
    max_backoff: Duration,
    base: fn() -> f64,
}

impl RetryConfig {
    /// Override `b` in the exponential backoff computation
    ///
    /// By default, `base` is a randomly generated value between 0 and 1. In tests, it can
    /// be helpful to override this:
    /// ```rust
    /// use aws_hyper::RetryConfig;
    /// let conf = RetryConfig::default().with_base(||1_f64);
    /// ```
    pub fn with_base(mut self, base: fn() -> f64) -> Self {
        self.base = base;
        self
    }
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            initial_retry_tokens: INITIAL_RETRY_TOKENS,
            retry_cost: RETRY_COST,
            no_retry_increment: 1,
            timeout_retry_cost: 10,
            max_retries: MAX_RETRIES,
            max_backoff: Duration::from_secs(20),
            // by default, use a random base for exponential backoff
            base: fastrand::f64,
        }
    }
}

const MAX_RETRIES: u32 = 3;
const INITIAL_RETRY_TOKENS: usize = 500;
const RETRY_COST: usize = 5;

/// Manage retries for a service
///
/// An implementation of the `standard` AWS retry strategy as specified in the SEP. A `Strategy` is scoped to a client.
/// For an individual request, call [`RetryHandlerFactory::new_handler()`](RetryHandlerFactory::new_handler)
///
/// In the future, adding support for the adaptive retry strategy will be added by adding a `TokenBucket` to
/// `CrossRequestRetryState`
/// Its main functionality is via `new_handler` which creates a `RetryHandler` to manage the retry for
/// an individual request.
pub struct RetryHandlerFactory {
    config: RetryConfig,
    shared_state: CrossRequestRetryState,
}

impl RetryHandlerFactory {
    pub fn new(config: RetryConfig) -> Self {
        Self {
            shared_state: CrossRequestRetryState::new(config.initial_retry_tokens),
            config,
        }
    }

    pub fn with_config(&mut self, config: RetryConfig) {
        self.config = config;
    }

    pub(crate) fn new_handler(&self) -> RetryHandler {
        RetryHandler {
            local: RequestLocalRetryState::new(),
            shared: self.shared_state.clone(),
            config: self.config.clone(),
        }
    }
}


#[derive(Default, Clone)]
struct RequestLocalRetryState {
    attempts: u32,
    last_quota_usage: Option<usize>,
}

impl RequestLocalRetryState {
    pub fn new() -> Self {
        Self::default()
    }
}

/* TODO in followup PR:
/// RetryPartition represents a scope for cross request retry state
///
/// For example, a retry partition could be the id of a service. This would give each service a separate retry budget.
struct RetryPartition(Cow<'static, str>); */

/// Shared state between multiple requests to the same client.
#[derive(Clone)]
struct CrossRequestRetryState {
    quota_available: Arc<Mutex<usize>>,
}

// clippy is upset that we didn't use AtomicUsize here, but doing so makes the code
// significantly more complicated for negligible benefit.
#[allow(clippy::mutex_atomic)]
impl CrossRequestRetryState {
    pub fn new(initial_quota: usize) -> Self {
        Self {
            quota_available:
            Arc::new(Mutex::new(initial_quota)),
        }
    }

    fn quota_release(&self, value: Option<usize>, config: &RetryConfig) {
        let mut quota = self.quota_available.lock().unwrap();
        *quota += value.unwrap_or(config.no_retry_increment);
    }

    /// Attempt to acquire retry quota for `ErrorKind`
    ///
    /// If quota is available, the amount of quota consumed is returned
    /// If no quota is available, `None` is returned.
    fn quota_acquire(&self, err: &ErrorKind, config: &RetryConfig) -> Option<usize> {
        let mut quota = self.quota_available.lock().unwrap();
        let retry_cost = if err == &ErrorKind::TransientError {
            config.timeout_retry_cost
        } else {
            config.retry_cost
        };
        if retry_cost > *quota {
            None
        } else {
            *quota -= retry_cost;
            Some(retry_cost)
        }
    }
}

/// RetryHandler
///
/// Implement retries for an individual request.
/// It is intended to be used as a [Tower Retry Policy](tower::retry::Policy) for use in tower-based
/// middleware stacks.
#[derive(Clone)]
pub(crate) struct RetryHandler {
    local: RequestLocalRetryState,
    shared: CrossRequestRetryState,
    config: RetryConfig,
}

#[cfg(test)]
impl RetryHandler {
    fn retry_quota(&self) -> usize {
        *self.shared.quota_available.lock().unwrap()
    }
}

impl RetryHandler {
    /// Determine the correct response given `retry_kind`
    ///
    /// If a retry is specified, this function returns `(next, backoff_duration)`
    /// If no retry is specified, this function returns None
    pub fn attempt_retry(&self, retry_kind: Result<(), ErrorKind>) -> Option<(Self, Duration)> {
        let quota_used = match retry_kind {
            Ok(_) => {
                self.shared
                    .quota_release(self.local.last_quota_usage, &self.config);
                return None;
            }
            Err(e) => {
                if self.local.attempts == self.config.max_retries - 1 {
                    return None;
                }
                self.shared.quota_acquire(&e, &self.config)?
            }
        };
        /*
        From the retry spec:
            b = random number within the range of: 0 <= b <= 1
            r = 2
            t_i = min(br^i, MAX_BACKOFF);
         */
        let r: i32 = 2;
        let b = (self.config.base)();
        let backoff = b * (r.pow(self.local.attempts) as f64);
        let backoff = Duration::from_secs_f64(backoff).min(self.config.max_backoff);
        let next = RetryHandler {
            local: RequestLocalRetryState {
                attempts: self.local.attempts + 1,
                last_quota_usage: Some(quota_used),
            },
            shared: self.shared.clone(),
            config: self.config.clone(),
        };

        Some((next, backoff))
    }
}

impl<Handler, R, T, E>
    tower::retry::Policy<operation::Operation<Handler, R>, SdkSuccess<T>, SdkError<E>>
    for RetryHandler
where
    E: ProvideErrorKind,
    Handler: Clone,
    R: ClassifyResponse<SdkSuccess<T>, SdkError<E>>,
{
    type Future = Pin<Box<dyn Future<Output = Self>>>;

    fn retry(
        &self,
        req: &Operation<Handler, R>,
        result: Result<&SdkSuccess<T>, &SdkError<E>>,
    ) -> Option<Self::Future> {
        let policy = req.retry_policy();
        let retry = policy.classify(result);
        let (next, fut) = match retry {
            RetryKind::Explicit(dur) => (self.clone(), dur),
            RetryKind::NotRetryable => return None,
            RetryKind::Error(err) => self.attempt_retry(Err(err))?,
            _ => return None,
        };
        let fut = async move {
            tokio::time::sleep(fut).await;
            next
        };
        Some(Box::pin(fut))
    }

    fn clone_request(&self, req: &Operation<Handler, R>) -> Option<Operation<Handler, R>> {
        req.try_clone()
    }
}

#[cfg(test)]
mod test {
    use crate::retry::{
        RetryConfig, RetryHandlerFactory,
    };
    use smithy_types::retry::ErrorKind;
    use std::time::Duration;

    fn test_config() -> RetryConfig {
        RetryConfig::default().with_base(|| 1_f64)
    }

    #[test]
    fn eventual_success() {
        let policy = RetryHandlerFactory::new(test_config()).new_handler();
        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(policy.retry_quota(), 495);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(policy.retry_quota(), 490);

        let no_retry = policy.attempt_retry(Ok(()));
        assert!(no_retry.is_none());
        assert_eq!(policy.retry_quota(), 495);
    }

    #[test]
    fn no_more_attempts() {
        let policy = RetryHandlerFactory::new(test_config()).new_handler();
        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(policy.retry_quota(), 495);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(policy.retry_quota(), 490);

        let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
        assert!(no_retry.is_none());
        assert_eq!(policy.retry_quota(), 490);
    }

    #[test]
    fn no_quota() {
        let mut conf = test_config();
        conf.initial_retry_tokens = 5;
        let policy = RetryHandlerFactory::new(conf).new_handler();
        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(policy.retry_quota(), 0);
        let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
        assert!(no_retry.is_none());
        assert_eq!(policy.retry_quota(), 0);
    }

    #[test]
    fn backoff_timing() {
        let mut conf = test_config();
        conf.max_retries = 5;
        let policy = RetryHandlerFactory::new(conf).new_handler();
        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(policy.retry_quota(), 495);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(policy.retry_quota(), 490);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(4));
        assert_eq!(policy.retry_quota(), 485);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(8));
        assert_eq!(policy.retry_quota(), 480);

        let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
        assert!(no_retry.is_none());
        assert_eq!(policy.retry_quota(), 480);
    }

    #[test]
    fn max_backoff_time() {
        let mut conf = test_config();
        conf.max_retries = 5;
        conf.max_backoff = Duration::from_secs(3);
        let policy = RetryHandlerFactory::new(conf).new_handler();
        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(policy.retry_quota(), 495);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(policy.retry_quota(), 490);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(policy.retry_quota(), 485);

        let (policy, dur) = policy
            .attempt_retry(Err(ErrorKind::ServerError))
            .expect("should retry");
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(policy.retry_quota(), 480);

        let no_retry = policy.attempt_retry(Err(ErrorKind::ServerError));
        assert!(no_retry.is_none());
        assert_eq!(policy.retry_quota(), 480);
    }
}
Loading