Unverified Commit 9959d9b9 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Introduce a `Standard` enum for connections (#237)

* Enum Conn POC

* Make alias public

* Add docs, Debug impl, format

* Add comment delete dead code

* A couple of missing cargo-formats

* Cleanup, address CR feedback

* Derive clone instead of implementing it manually
parent d364e694
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ smithy-types = { path = "../../../rust-runtime/smithy-types" }
smithy-http-tower = { path = "../../../rust-runtime/smithy-http-tower" }
fastrand = "1.4.0"
tokio = { version = "1", features = ["time"]}
pin-project = "1"
tracing = "0.1.25"

[dev-dependencies]
+154 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::BoxError;
use http::Request;
use hyper::client::{HttpConnector, ResponseFuture};
use hyper::Response;
use hyper_tls::HttpsConnector;
use smithy_http::body::SdkBody;
use std::future::{Future, Ready};
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::Service;

#[derive(Clone)]
pub struct Standard(Connector);

impl Standard {
    /// An https connection
    pub fn https() -> Self {
        let https = HttpsConnector::new();
        Self(Connector::Https(hyper::Client::builder().build::<_, SdkBody>(https)))
    }

    /// A connection based on the provided `Box<dyn HttpService>`
    ///
    /// Generally, `https()` should be used instead. This constructor is intended to support
    /// using things like [`TestConnection`](crate::test_connection::TestConnection) or alternative
    /// http implementations.
    pub fn new(connector: Box<dyn HttpService>) -> Self {
        Self(Connector::Dyn(connector))
    }
}

/// An Http connection type for most use cases
///
/// This supports three options:
/// 1. HTTPS
/// 2. A `TestConnection`
/// 3. Any implementation of the `HttpService` trait
///
/// This is designed to be used with [`aws_hyper::Client`](crate::Client) as a connector.
#[derive(Clone)]
enum Connector {
    /// An Https Connection
    ///
    /// This is the correct connection for use cases talking to real AWS services.
    Https(hyper::Client<HttpsConnector<HttpConnector>, SdkBody>),

    /// A generic escape hatch
    ///
    /// This enables using any implementation of the HttpService trait. This allows using a totally
    /// separate HTTP stack or your own custom `TestConnection`.
    Dyn(Box<dyn HttpService>),
}

impl Clone for Box<dyn HttpService> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

pub trait HttpService: Send {
    /// Return whether this service is ready to accept a request
    ///
    /// See [`Service::poll_ready`](tower::Service::poll_ready)
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>>;

    /// Call this service and return a response
    ///
    /// See [`Service::call`](tower::Service::call)
    fn call(
        &mut self,
        req: http::Request<SdkBody>,
    ) -> Pin<Box<dyn Future<Output = Result<http::Response<hyper::Body>, BoxError>> + Send>>;

    /// Return a Boxed-clone of this service
    ///
    /// `aws_hyper::Client` will clone the inner service for each request so this should be a cheap
    /// clone operation.
    fn clone_box(&self) -> Box<dyn HttpService>;
}

/// Reverse implementation: If you have a correctly shaped tower service, it _is_ an `HttpService`
///
/// This is to facilitate ease of use for people using `Standard::Dyn`
impl<S> HttpService for S
where
    S: Service<http::Request<SdkBody>, Response = http::Response<hyper::Body>>
        + Send
        + Clone
        + 'static,
    S::Error: Into<BoxError> + Send + Sync + 'static,
    S::Future: Send + 'static,
{
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        Service::poll_ready(self, cx).map_err(|err| err.into())
    }

    fn call(
        &mut self,
        req: Request<SdkBody>,
    ) -> Pin<Box<dyn Future<Output = Result<Response<hyper::Body>, BoxError>> + Send>> {
        let fut = Service::call(self, req);
        let fut = async move { fut.await.map_err(|err| err.into()) };
        Box::pin(fut)
    }

    fn clone_box(&self) -> Box<dyn HttpService> {
        Box::new(self.clone())
    }
}

impl tower::Service<http::Request<SdkBody>> for Standard {
    type Response = http::Response<hyper::Body>;
    type Error = BoxError;
    type Future = StandardFuture;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match &mut self.0 {
            Connector::Https(https) => Service::poll_ready(https, cx).map_err(|err| err.into()),
            Connector::Dyn(conn) => conn.poll_ready(cx),
        }
    }

    fn call(&mut self, req: http::Request<SdkBody>) -> Self::Future {
        match &mut self.0 {
            Connector::Https(https) => StandardFuture::Https(Service::call(https, req)),
            Connector::Dyn(conn) => StandardFuture::Dyn(conn.call(req)),
        }
    }
}

/// Future returned by `Standard` when used as a tower::Service
#[pin_project::pin_project(project = FutProj)]
pub enum StandardFuture {
    Https(#[pin] ResponseFuture),
    TestConn(#[pin] Ready<Result<http::Response<hyper::Body>, BoxError>>),
    Dyn(#[pin] Pin<Box<dyn Future<Output = Result<http::Response<hyper::Body>, BoxError>> + Send>>),
}

impl Future for StandardFuture {
    type Output = Result<http::Response<hyper::Body>, BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            FutProj::TestConn(ready_fut) => ready_fut.poll(cx),
            FutProj::Https(fut) => fut.poll(cx).map_err(|err| err.into()),
            FutProj::Dyn(dyn_fut) => dyn_fut.poll(cx),
        }
    }
}
+30 −14
Original line number Diff line number Diff line
pub mod conn;
mod retry;
pub mod test_connection;

pub use retry::RetryConfig;

use crate::conn::Standard;
use crate::retry::RetryHandlerFactory;
use aws_endpoint::AwsEndpointStage;
use aws_http::user_agent::UserAgentStage;
use aws_sig_auth::middleware::SigV4SigningStage;
use aws_sig_auth::signer::SigV4Signer;
use hyper::client::HttpConnector;
use hyper::Client as HyperClient;
use hyper_tls::HttpsConnector;
use smithy_http::body::SdkBody;
use smithy_http::operation::Operation;
use smithy_http::response::ParseHttpResponse;
@@ -19,9 +19,12 @@ use smithy_http_tower::map_request::MapRequestLayer;
use smithy_http_tower::parse_response::ParseResponseLayer;
use smithy_types::retry::ProvideErrorKind;
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Formatter};
use tower::{Service, ServiceBuilder, ServiceExt};

type BoxError = Box<dyn Error + Send + Sync>;
pub type StandardClient = Client<conn::Standard>;

pub type SdkError<E> = smithy_http::result::SdkError<E, hyper::Body>;
pub type SdkSuccess<T> = smithy_http::result::SdkSuccess<T, hyper::Body>;
@@ -41,10 +44,17 @@ pub type SdkSuccess<T> = smithy_http::result::SdkSuccess<T, hyper::Body>;
///    S::Error: Into<BoxError> + Send + Sync + 'static,
///    S::Future: Send + 'static,
/// ```

pub struct Client<S> {
    inner: S,
    retry_strategy: RetryHandlerFactory,
    retry_handler: RetryHandlerFactory,
}

impl<S> Debug for Client<S> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut formatter = f.debug_struct("Client");
        formatter.field("retry_handler", &self.retry_handler);
        formatter.finish()
    }
}

impl<S> Client<S> {
@@ -52,24 +62,22 @@ impl<S> Client<S> {
    pub fn new(connector: S) -> Self {
        Client {
            inner: connector,
            retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
            retry_handler: RetryHandlerFactory::new(RetryConfig::default()),
        }
    }

    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_strategy.with_config(retry_config);
        self.retry_handler.with_config(retry_config);
        self
    }
}

impl Client<hyper::Client<HttpsConnector<HttpConnector>, SdkBody>> {
impl Client<Standard> {
    /// Construct an `https` based client
    pub fn https() -> Self {
        let https = HttpsConnector::new();
        let client = HyperClient::builder().build::<_, SdkBody>(https);
    pub fn https() -> StandardClient {
        Client {
            inner: client,
            retry_strategy: RetryHandlerFactory::new(RetryConfig::default()),
            inner: Standard::https(),
            retry_handler: RetryHandlerFactory::new(RetryConfig::default()),
        }
    }
}
@@ -115,7 +123,7 @@ where
        let inner = self.inner.clone();
        let mut svc = ServiceBuilder::new()
            // Create a new request-scoped policy
            .retry(self.retry_strategy.new_handler())
            .retry(self.retry_handler.new_handler())
            .layer(ParseResponseLayer::<O, Retry>::new())
            .layer(endpoint_resolver)
            .layer(signer)
@@ -135,4 +143,12 @@ mod tests {
    fn construct_default_client() {
        let _ = Client::https();
    }

    #[test]
    fn client_debug_includes_retry_info() {
        let client = Client::https();
        let s = format!("{:?}", client);
        assert!(s.contains("RetryConfig"));
        assert!(s.contains("quota_available"));
    }
}
+3 −2
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ use tracing::Instrument;
/// Without specific use cases, users should generally rely on the default values set by `[RetryConfig::default]`(RetryConfig::default).`
///
/// Currently these fields are private and no setters provided. As needed, this configuration will become user-modifiable in the future..
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct RetryConfig {
    initial_retry_tokens: usize,
    retry_cost: usize,
@@ -87,6 +87,7 @@ const RETRY_COST: usize = 5;
/// `CrossRequestRetryState`
/// Its main functionality is via `new_handler` which creates a `RetryHandler` to manage the retry for
/// an individual request.
#[derive(Debug)]
pub struct RetryHandlerFactory {
    config: RetryConfig,
    shared_state: CrossRequestRetryState,
@@ -132,7 +133,7 @@ impl RequestLocalRetryState {
struct RetryPartition(Cow<'static, str>); */

/// Shared state between multiple requests to the same client.
#[derive(Clone)]
#[derive(Clone, Debug)]
struct CrossRequestRetryState {
    quota_available: Arc<Mutex<usize>>,
}
+11 −1
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ impl ValidateRequest {
/// - Response to requests with a preloaded series of responses
/// - Record requests for future examination
///
/// The generic parameter `B` is the type of the response body.
/// For more complex use cases, see [Tower Test](https://docs.rs/tower-test/0.4.0/tower_test/)
/// Usage example:
/// ```rust
@@ -62,12 +63,21 @@ impl ValidateRequest {
/// let conn = TestConnection::new(events);
/// let client = aws_hyper::Client::new(conn);
/// ```
#[derive(Clone)]
pub struct TestConnection<B> {
    data: Arc<Mutex<ConnectVec<B>>>,
    requests: Arc<Mutex<Vec<ValidateRequest>>>,
}

// Need a clone impl that ignores `B`
impl<B> Clone for TestConnection<B> {
    fn clone(&self) -> Self {
        TestConnection {
            data: self.data.clone(),
            requests: self.requests.clone(),
        }
    }
}

impl<B> TestConnection<B> {
    pub fn new(mut data: ConnectVec<B>) -> Self {
        data.reverse();
Loading