Unverified Commit 9d2d0880 authored by Burak's avatar Burak Committed by GitHub
Browse files

Python: TLS Support for servers (#2002)

* Add `PyTlsConfig` struct

* Support TLS server

* Add TLS support to Pokemon service

* Make sure to create `tokio::net::TcpListener` in a Tokio context

* Fix doc link

* Add missing imports in tests

* Add `tls::Listener`

* Reload TLS config periodically

* Add context to `TODO`

* Return `&'static str` from `base_url()`

* Flatten `match` in `tls::Listener`

* Propogate listener errors but ignore handshake errors in `tls::Listener`

* Add tests to `tls::Listener`

* Add test to `tls::Listener` to make sure we are propogating listener errors

* Use `PathBuf` instead of plain `String`s for paths
parent b2b8e7e4
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -264,7 +264,7 @@ class PythonApplicationGenerator(
                    Ok(())
                }
                /// Main entrypoint: start the server on multiple workers.
                ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
                ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers, tls)")]
                pub fn run(
                    &mut self,
                    py: #{pyo3}::Python,
@@ -272,9 +272,10 @@ class PythonApplicationGenerator(
                    port: Option<i32>,
                    backlog: Option<i32>,
                    workers: Option<usize>,
                    tls: Option<#{SmithyPython}::tls::PyTlsConfig>,
                ) -> #{pyo3}::PyResult<()> {
                    use #{SmithyPython}::PyApp;
                    self.run_server(py, address, port, backlog, workers)
                    self.run_server(py, address, port, backlog, workers, tls)
                }
                /// Lambda entrypoint: start the server on Lambda.
                ##[cfg(feature = "aws-lambda")]
@@ -287,17 +288,18 @@ class PythonApplicationGenerator(
                    self.run_lambda_handler(py)
                }
                /// Build the service and start a single worker.
                ##[pyo3(text_signature = "(${'$'}self, socket, worker_number)")]
                ##[pyo3(text_signature = "(${'$'}self, socket, worker_number, tls)")]
                pub fn start_worker(
                    &mut self,
                    py: pyo3::Python,
                    socket: &pyo3::PyCell<#{SmithyPython}::PySocket>,
                    worker_number: isize,
                    tls: Option<#{SmithyPython}::tls::PyTlsConfig>,
                ) -> pyo3::PyResult<()> {
                    use #{SmithyPython}::PyApp;
                    let event_loop = self.configure_python_event_loop(py)?;
                    let service = self.build_and_configure_service(py, event_loop)?;
                    self.start_hyper_worker(py, socket, event_loop, service, worker_number)
                    self.start_hyper_worker(py, socket, event_loop, service, worker_number, tls)
                }
                """,
                *codegenScope,
+17 −0
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@ class PythonServerModuleGenerator(
                renderPySocketType()
                renderPyLogging()
                renderPyMiddlewareTypes()
                renderPyTlsTypes()
                renderPyApplicationType()
            }
        }
@@ -162,6 +163,22 @@ class PythonServerModuleGenerator(
        )
    }

    private fun RustWriter.renderPyTlsTypes() {
        rustTemplate(
            """
            let tls = #{pyo3}::types::PyModule::new(py, "tls")?;
            tls.add_class::<#{SmithyPython}::tls::PyTlsConfig>()?;
            pyo3::py_run!(
                py,
                tls,
                "import sys; sys.modules['$libName.tls'] = tls"
            );
            m.add_submodule(tls)?;
            """,
            *codegenScope,
        )
    }

    // Render Python application type.
    private fun RustWriter.renderPyApplicationType() {
        rustTemplate(
+5 −0
Original line number Diff line number Diff line
@@ -26,6 +26,9 @@ futures = "0.3"
http = "0.2"
hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] }
lambda_http = { version = "0.7.1", optional = true }
tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h2"] }
rustls-pemfile = "1.0.1"
tokio-rustls = "0.23.4"
num_cpus = "1.13.1"
parking_lot = "0.12.1"
pin-project-lite = "0.2"
@@ -47,6 +50,8 @@ futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }
rcgen = "0.10.0"
hyper-rustls = { version = "0.23.1", features = ["http2"] }

[[test]]
name = "middleware_tests"
+5 −0
Original line number Diff line number Diff line
@@ -9,8 +9,13 @@ description = "Run tests against the Python server implementation"
[dev-dependencies]
command-group = "1.0"
tokio = { version = "1.20.1", features = ["full"] }
serial_test = "0.9.0"
rustls-pemfile = "1.0.1"
tokio-rustls = "0.23.4"
hyper-rustls = { version = "0.23.0", features = ["http2"] }

# Local paths
aws-smithy-client  = { path = "../../../aws-smithy-client/", features = ["rustls"] }
aws-smithy-http = { path = "../../../aws-smithy-http/" }
aws-smithy-types  = { path = "../../../aws-smithy-types/" }
pokemon-service-client = { path = "../pokemon-service-client/" }
+102 −19
Original line number Diff line number Diff line
@@ -3,9 +3,61 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use std::fs::File;
use std::io::BufReader;
use std::process::Command;
use std::time::Duration;

use aws_smithy_client::{erase::DynConnector, hyper_ext::Adapter};
use aws_smithy_http::operation::Request;
use command_group::{CommandGroup, GroupChild};
use pokemon_service_client::{Builder, Client, Config};
use std::{process::Command, thread, time::Duration};
use tokio::time;

const TEST_KEY: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/testdata/localhost.key");
const TEST_CERT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/testdata/localhost.crt");

pub type PokemonClient = Client<
    aws_smithy_client::erase::DynConnector,
    aws_smithy_client::erase::DynMiddleware<aws_smithy_client::erase::DynConnector>,
>;

enum PokemonServiceVariant {
    Http,
    Http2,
}

impl PokemonServiceVariant {
    async fn run_process(&self) -> GroupChild {
        let mut args = vec!["../pokemon_service.py".to_string()];

        match self {
            PokemonServiceVariant::Http => {}
            PokemonServiceVariant::Http2 => {
                args.push("--enable-tls".to_string());
                args.push(format!("--tls-key-path={TEST_KEY}"));
                args.push(format!("--tls-cert-path={TEST_CERT}"));
            }
        }

        let process = Command::new("python3")
            .args(args)
            .group_spawn()
            .expect("failed to spawn the Pokémon Service program");

        // The Python interpreter takes a little to startup.
        time::sleep(Duration::from_secs(2)).await;

        process
    }

    fn base_url(&self) -> &'static str {
        match self {
            PokemonServiceVariant::Http => "http://localhost:13734",
            PokemonServiceVariant::Http2 => "https://localhost:13734",
        }
    }
}

pub(crate) struct PokemonService {
    // We need to ensure all processes forked by the Python interpreter
@@ -16,15 +68,16 @@ pub(crate) struct PokemonService {

impl PokemonService {
    #[allow(dead_code)]
    pub(crate) fn run() -> Self {
        let process = Command::new("python3")
            .arg("../pokemon_service.py")
            .group_spawn()
            .expect("failed to spawn the Pokémon Service program");
        // The Python interpreter takes a little to startup.
        thread::sleep(Duration::from_secs(2));
    pub(crate) async fn run() -> Self {
        Self {
            child_process: process,
            child_process: PokemonServiceVariant::Http.run_process().await,
        }
    }

    #[allow(dead_code)]
    pub(crate) async fn run_http2() -> Self {
        Self {
            child_process: PokemonServiceVariant::Http2.run_process().await,
        }
    }
}
@@ -39,19 +92,49 @@ impl Drop for PokemonService {
}

#[allow(dead_code)]
pub fn client() -> Client<
    aws_smithy_client::erase::DynConnector,
    aws_smithy_client::erase::DynMiddleware<aws_smithy_client::erase::DynConnector>,
> {
pub fn client() -> PokemonClient {
    let base_url = PokemonServiceVariant::Http.base_url();
    let raw_client = Builder::new()
        .rustls_connector(Default::default())
        .middleware_fn(|mut req| {
            let http_req = req.http_mut();
            let uri = format!("http://localhost:13734{}", http_req.uri().path());
            *http_req.uri_mut() = uri.parse().unwrap();
            req
        })
        .middleware_fn(rewrite_base_url(base_url))
        .build_dyn();
    let config = Config::builder().build();
    Client::with_config(raw_client, config)
}

#[allow(dead_code)]
pub fn http2_client() -> PokemonClient {
    // Create custom cert store and add our test certificate to prevent unknown cert issues.
    let mut reader = BufReader::new(File::open(TEST_CERT).expect("could not open certificate"));
    let certs = rustls_pemfile::certs(&mut reader).expect("could not parse certificate");
    let mut roots = tokio_rustls::rustls::RootCertStore::empty();
    roots.add_parsable_certificates(&certs);

    let connector = hyper_rustls::HttpsConnectorBuilder::new()
        .with_tls_config(
            tokio_rustls::rustls::ClientConfig::builder()
                .with_safe_defaults()
                .with_root_certificates(roots)
                .with_no_client_auth(),
        )
        .https_only()
        .enable_http2()
        .build();

    let base_url = PokemonServiceVariant::Http2.base_url();
    let raw_client = Builder::new()
        .connector(DynConnector::new(Adapter::builder().build(connector)))
        .middleware_fn(rewrite_base_url(base_url))
        .build_dyn();
    let config = Config::builder().build();
    Client::with_config(raw_client, config)
}

fn rewrite_base_url(base_url: &'static str) -> impl Fn(Request) -> Request + Clone {
    move |mut req| {
        let http_req = req.http_mut();
        let uri = format!("{base_url}{}", http_req.uri().path());
        *http_req.uri_mut() = uri.parse().unwrap();
        req
    }
}
Loading