Unverified Commit 5cbfbf58 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support (& default to) Rustls (#311)

* Add support (& default to) Rustls

* Add example of switching to native tls

* Delete dead code

* Update conn.rs
parent fc0c6ea9
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -10,11 +10,14 @@ license = "Apache-2.0"
[features]
test-util = ["protocol-test-helpers"]
default = ["test-util"]
native-tls = ["hyper-tls"]
rustls = ["hyper-rustls"]

[dependencies]
hyper = { version = "0.14.2", features = ["client", "http1", "http2", "tcp", "runtime"] }
tower = { version = "0.4.6", features = ["util", "retry"] }
hyper-tls = "0.5.0"
hyper-tls = { version ="0.5.0", optional = true }
hyper-rustls = { version = "0.22.1", optional = true, features = ["rustls-native-certs"] }
aws-auth = { path = "../aws-auth" }
aws-sig-auth = { path = "../aws-sig-auth" }
aws-endpoint = { path = "../aws-endpoint" }
+50 −9
Original line number Diff line number Diff line
@@ -5,9 +5,8 @@

use crate::BoxError;
use http::Request;
use hyper::client::{HttpConnector, ResponseFuture};
use hyper::client::ResponseFuture;
use hyper::Response;
use hyper_tls::HttpsConnector;
use smithy_http::body::SdkBody;
use std::future::{Future, Ready};
use std::pin::Pin;
@@ -19,11 +18,36 @@ pub struct Standard(Connector);

impl Standard {
    /// An https connection
    ///
    /// If the `rustls` feature is enabled, this will use `rustls`.
    /// If the ONLY the `native-tls` feature is enabled, this will use `native-tls`.
    /// If both features are enabled, this will use `rustls`
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    pub fn https() -> Self {
        let https = HttpsConnector::new();
        Self(Connector::Https(
            hyper::Client::builder().build::<_, SdkBody>(https),
        ))
        #[cfg(feature = "rustls")]
        {
            Self::rustls()
        }

        // If we are compiling this function & rustls is not enabled, then native-tls MUST be enabled
        #[cfg(not(feature = "rustls"))]
        {
            Self::native_tls()
        }
    }

    #[cfg(feature = "rustls")]
    pub fn rustls() -> Self {
        let https = hyper_rustls::HttpsConnector::with_native_roots();
        let client = hyper::Client::builder().build::<_, SdkBody>(https);
        Self(Connector::RustlsHttps(client))
    }

    #[cfg(feature = "native-tls")]
    pub fn native_tls() -> Self {
        let https = hyper_tls::HttpsConnector::new();
        let client = hyper::Client::builder().build::<_, SdkBody>(https);
        Self(Connector::NativeHttps(client))
    }

    /// A connection based on the provided `impl HttpService`
@@ -49,7 +73,14 @@ enum Connector {
    /// An Https Connection
    ///
    /// This is the correct connection for use cases talking to real AWS services.
    Https(hyper::Client<HttpsConnector<HttpConnector>, SdkBody>),
    #[cfg(feature = "native-tls")]
    NativeHttps(hyper::Client<hyper_tls::HttpsConnector<hyper::client::HttpConnector>, SdkBody>),

    /// An Https Connection
    ///
    /// This is the correct connection for use cases talking to real AWS services.
    #[cfg(feature = "rustls")]
    RustlsHttps(hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>, SdkBody>),

    /// A generic escape hatch
    ///
@@ -123,14 +154,24 @@ impl tower::Service<http::Request<SdkBody>> for Standard {

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match &mut self.0 {
            Connector::Https(https) => Service::poll_ready(https, cx).map_err(|err| err.into()),
            #[cfg(feature = "native-tls")]
            Connector::NativeHttps(https) => {
                Service::poll_ready(https, cx).map_err(|err| err.into())
            }
            #[cfg(feature = "rustls")]
            Connector::RustlsHttps(https) => {
                Service::poll_ready(https, cx).map_err(|err| err.into())
            }
            Connector::Dyn(conn) => conn.poll_ready(cx),
        }
    }

    fn call(&mut self, req: http::Request<SdkBody>) -> Self::Future {
        match &mut self.0 {
            Connector::Https(https) => StandardFuture::Https(Service::call(https, req)),
            #[cfg(feature = "native-tls")]
            Connector::NativeHttps(https) => StandardFuture::Https(Service::call(https, req)),
            #[cfg(feature = "rustls")]
            Connector::RustlsHttps(https) => StandardFuture::Https(Service::call(https, req)),
            Connector::Dyn(conn) => StandardFuture::Dyn(conn.call(req)),
        }
    }
+7 −11
Original line number Diff line number Diff line
@@ -78,6 +78,7 @@ impl<S> Client<S> {

impl Client<Standard> {
    /// Construct an `https` based client
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    pub fn https() -> StandardClient {
        Client {
            inner: Standard::https(),
@@ -141,26 +142,21 @@ where

#[cfg(test)]
mod tests {
    use crate::Client;

    #[cfg(any(feature = "rustls", feature = "native-tls"))]
    #[test]
    fn construct_default_client() {
        let _ = Client::https();
        let c = crate::Client::https();
        fn is_send_sync<T: Send + Sync>(_c: T) {}
        is_send_sync(c);
    }

    #[cfg(any(feature = "rustls", feature = "native-tls"))]
    #[test]
    fn client_debug_includes_retry_info() {
        let client = Client::https();
        let client = crate::Client::https();
        let s = format!("{:?}", client);
        assert!(s.contains("RetryConfig"));
        assert!(s.contains("quota_available"));
    }

    fn is_send_sync<T: Send + Sync>(_: T) {}

    #[test]
    fn client_is_send_sync() {
        let c = Client::https();
        is_send_sync(c);
    }
}
+10 −1
Original line number Diff line number Diff line
@@ -45,7 +45,10 @@ class FluentClientDecorator : RustCodegenDecorator {
        rustCrate.withModule(RustModule("client", module)) { writer ->
            FluentClientGenerator(protocolConfig).render(writer)
        }
        rustCrate.addFeature(Feature("client", true, listOf(protocolConfig.runtimeConfig.awsHyper().name)))
        val awsHyper = protocolConfig.runtimeConfig.awsHyper().name
        rustCrate.addFeature(Feature("client", true, listOf(awsHyper)))
        rustCrate.addFeature(Feature("rustls", default = true, listOf("$awsHyper/rustls")))
        rustCrate.addFeature(Feature("native-tls", default = false, listOf("$awsHyper/native-tls")))
    }

    override fun libRsCustomizations(
@@ -87,10 +90,16 @@ class FluentClientGenerator(protocolConfig: ProtocolConfig) {
        writer.rustBlock("impl Client") {
            rustTemplate(
                """
                ##[cfg(any(feature = "rustls", feature = "native-tls"))]
                pub fn from_env() -> Self {
                    Self::from_conf_conn(crate::Config::builder().build(), #{aws_hyper}::conn::Standard::https())
                }

                ##[cfg(any(feature = "rustls", feature = "native-tls"))]
                pub fn from_conf(conf: crate::Config) -> Self {
                    Self::from_conf_conn(conf, #{aws_hyper}::conn::Standard::https())
                }

                pub fn from_conf_conn(conf: crate::Config, conn: #{aws_hyper}::conn::Standard) -> Self {
                    let client = #{aws_hyper}::Client::new(conn);
                    Self { handle: std::sync::Arc::new(Handle { conf, client })}
+4 −1
Original line number Diff line number Diff line
@@ -7,7 +7,10 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dynamodb = { package = "aws-sdk-dynamodb", path = "../../build/aws-sdk/dynamodb", features = ["client"] }
dynamodb = { package = "aws-sdk-dynamodb", path = "../../build/aws-sdk/dynamodb" }
### To use native TLS:
# dynamodb = { package = "aws-sdk-dynamodb", path = "../../build/aws-sdk/dynamodb", default-features = false, features = ["native-tls"] }

tokio = { version = "1", features = ["full"] }

# used only for static endpoint configuration: