Unverified Commit 86539f2b authored by Burak's avatar Burak Committed by GitHub
Browse files

Add TLS support to Pokemon server (#1634)

* Add TLS support to Pokemon server

* Add missing copyright header to TLS module

* Handle connection errors

* Move TLS example to own binary

* Update comments according to reviews

* Move rewrite base url middleware to its own function
parent e185990a
Loading
Loading
Loading
Loading
+16 −1
Original line number Diff line number Diff line
@@ -5,6 +5,11 @@ edition = "2021"
publish = false
authors = ["Smithy-rs Server Team <smithy-rs-server@amazon.com>"]
description = "A smithy Rust service to retrieve information about Pokémon."
default-run = "pokemon-service"

[[bin]]
name = "pokemon-service-tls"
path = "src/bin/pokemon-service-tls.rs"

[dependencies]
async-stream = "0.3"
@@ -17,6 +22,12 @@ tower-http = { version = "0.3", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

# These dependencies are only required for `pokemon-service-tls`.
tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h2"] }
tokio-rustls = "0.23.4"
rustls-pemfile = "1.0.1"
futures-util = "0.3"

# Local paths
aws-smithy-http-server = { path = "../../" }
pokemon-service-server-sdk = { path = "../pokemon-service-server-sdk/" }
@@ -27,6 +38,10 @@ home = "0.5"
serial_test = "0.7.0"
wrk-api-bench = "0.0.7"

# These dependencies are only required for testing `pokemon-service-tls`.
hyper-rustls = { version = "0.23.0", features = ["http2"] }

# Local paths
aws-smithy-client = { path = "../../../aws-smithy-client/", features = ["rustls"] }
aws-smithy-http = { path = "../../../aws-smithy-http/" }
pokemon-service-client = { path = "../pokemon-service-client/" }
+152 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

// This program is exported as a binary named `pokemon-service-tls`.
// It uses `tls-listener`, `tokio-rustls` (and `rustls-pemfile` to parse PEM files)
// to serve TLS connections. It also enables h2 ALPN protocol,
// without this clients by default don't upgrade to http2.
//
// You can use `mkcert` (https://github.com/FiloSottile/mkcert) to create certificates for testing:
// `$ mkcert localhost`
// it should create `./localhost.pem` and `./localhost-key.pem`,
// then you can run TLS server via:
// `$ cargo run --bin pokemon-service-tls -- --tls-cert-path ./localhost.pem --tls-key-path ./localhost-key.pem`
// and test it:
// ```bash
// $ curl -k -D- -H "Accept: application/json" https://localhost:13734/pokemon-species/pikachu
// HTTP/2 200
// # ...
// ```
// note that by default created certificates will be unknown and you should use `-k|--insecure`
// flag while making requests with cURL or you can run `mkcert -install` to trust certificates created by `mkcert`.

use std::fs::File;
use std::future;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::Arc;

use aws_smithy_http_server::{AddExtensionLayer, Router};
use clap::Parser;
use futures_util::stream::StreamExt;
use pokemon_service::{
    capture_pokemon, empty_operation, get_pokemon_species, get_server_statistics, get_storage, health_check_operation,
    setup_tracing, State,
};
use pokemon_service_server_sdk::operation_registry::OperationRegistryBuilder;
use tokio_rustls::{
    rustls::{Certificate, PrivateKey, ServerConfig},
    TlsAcceptor,
};
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Hyper server bind address.
    #[clap(short, long, action, default_value = "127.0.0.1")]
    address: String,
    /// Hyper server bind port.
    #[clap(short, long, action, default_value = "13734")]
    port: u16,
    /// Hyper server TLS certificate path. Must be a PEM file.
    #[clap(long, default_value = "")]
    tls_cert_path: String,
    /// Hyper server TLS private key path. Must be a PEM file.
    #[clap(long, default_value = "")]
    tls_key_path: String,
}

#[tokio::main]
pub async fn main() {
    let args = Args::parse();
    setup_tracing();
    let app: Router = OperationRegistryBuilder::default()
        // Build a registry containing implementations to all the operations in the service. These
        // are async functions or async closures that take as input the operation's input and
        // return the operation's output.
        .get_pokemon_species(get_pokemon_species)
        .get_storage(get_storage)
        .get_server_statistics(get_server_statistics)
        .capture_pokemon_operation(capture_pokemon)
        .empty_operation(empty_operation)
        .health_check_operation(health_check_operation)
        .build()
        .expect("Unable to build operation registry")
        // Convert it into a router that will route requests to the matching operation
        // implementation.
        .into();

    // Setup shared state and middlewares.
    let shared_state = Arc::new(State::default());
    let app = app.layer(
        ServiceBuilder::new()
            .layer(TraceLayer::new_for_http())
            .layer(AddExtensionLayer::new(shared_state)),
    );

    let addr: SocketAddr = format!("{}:{}", args.address, args.port)
        .parse()
        .expect("unable to parse the server bind address and port");

    let acceptor = acceptor(&args.tls_cert_path, &args.tls_key_path);
    let listener = tls_listener::TlsListener::new(
        acceptor,
        hyper::server::conn::AddrIncoming::bind(&addr).expect("could not bind"),
    )
    .filter(|conn| {
        if let Err(err) = conn {
            eprintln!("connection error: {:?}", err);
            future::ready(false)
        } else {
            future::ready(true)
        }
    });
    let server = hyper::Server::builder(hyper::server::accept::from_stream(listener)).serve(app.into_make_service());
    if let Err(err) = server.await {
        eprintln!("server error: {}", err);
    }
}

// Returns a `TlsAcceptor` that can be used to create `TlsListener`
// which then can be used with Hyper.
pub fn acceptor(cert_path: &str, key_path: &str) -> TlsAcceptor {
    let certs = load_certs(cert_path);
    let key = load_key(key_path);
    let mut server_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .expect("could not create server config");

    // If we don't state we are accepting "h2", clients by default don't negotiate way up to http2.
    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];

    TlsAcceptor::from(Arc::new(server_config))
}

fn load_certs(path: &str) -> Vec<Certificate> {
    let mut reader = BufReader::new(File::open(path).expect("could not open certificate"));
    rustls_pemfile::certs(&mut reader)
        .expect("could not parse certificate")
        .into_iter()
        .map(Certificate)
        .collect()
}

fn load_key(path: &str) -> PrivateKey {
    let mut reader = BufReader::new(File::open(path).expect("could not open private key"));
    loop {
        match rustls_pemfile::read_one(&mut reader).expect("could not parse private key") {
            Some(rustls_pemfile::Item::RSAKey(key)) => return PrivateKey(key),
            Some(rustls_pemfile::Item::PKCS8Key(key)) => return PrivateKey(key),
            Some(rustls_pemfile::Item::ECKey(key)) => return PrivateKey(key),
            None => break,
            _ => {}
        }
    }
    panic!("invalid private key")
}
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
 * SPDX-License-Identifier: Apache-2.0
 */

// This program is exported as a binary named `pokemon_service`.
// This program is exported as a binary named `pokemon-service`.
use std::{net::SocketAddr, sync::Arc};

use aws_smithy_http_server::{AddExtensionLayer, Router};
+96 −11
Original line number Diff line number Diff line
@@ -3,20 +3,67 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use std::fs::File;
use std::io::BufReader;
use std::process::{Child, Command};
use std::time::Duration;

use assert_cmd::prelude::*;
use aws_smithy_client::{erase::DynConnector, hyper_ext::Adapter};
use aws_smithy_http::operation::Request;
use pokemon_service_client::{Builder, Client, Config};
use std::process::Command;
use tokio::time;

const TEST_KEY: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/testdata/localhost.key");
const TEST_CERT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/testdata/localhost.crt");

enum PokemonServiceVariant {
    Http,
    Https,
}

impl PokemonServiceVariant {
    async fn run_process(&self) -> Child {
        let process = match self {
            PokemonServiceVariant::Http => Command::cargo_bin("pokemon-service").unwrap().spawn().unwrap(),
            PokemonServiceVariant::Https => Command::cargo_bin("pokemon-service-tls")
                .unwrap()
                .args(["--tls-cert-path", TEST_CERT, "--tls-key-path", TEST_KEY])
                .spawn()
                .unwrap(),
        };

        // Give PokémonService some time to start up.
        time::sleep(Duration::from_millis(500)).await;

        process
    }

    fn base_url(&self) -> String {
        match self {
            PokemonServiceVariant::Http => "http://localhost:13734".to_string(),
            PokemonServiceVariant::Https => "https://localhost:13734".to_string(),
        }
    }
}

pub(crate) struct PokemonService {
    child_process: std::process::Child,
    child_process: Child,
}

impl PokemonService {
    #[allow(dead_code)]
    pub(crate) fn run() -> Self {
        let process = Command::cargo_bin("pokemon-service").unwrap().spawn().unwrap();
    pub(crate) async fn run() -> Self {
        Self {
            child_process: PokemonServiceVariant::Http.run_process().await,
        }
    }

        Self { child_process: process }
    #[allow(dead_code)]
    pub(crate) async fn run_https() -> Self {
        Self {
            child_process: PokemonServiceVariant::Https.run_process().await,
        }
    }
}

@@ -33,15 +80,53 @@ pub fn client() -> Client<
    aws_smithy_client::erase::DynConnector,
    aws_smithy_client::erase::DynMiddleware<aws_smithy_client::erase::DynConnector>,
> {
    let base_url = PokemonServiceVariant::Http.base_url();
    let raw_client = Builder::new()
        .rustls()
        .middleware_fn(|mut req| {
            let http_req = req.http_mut();
            let uri = format!("http://localhost:13734{}", http_req.uri().path());
            *http_req.uri_mut() = uri.parse().unwrap();
            req
        })
        .middleware_fn(rewrite_base_url(base_url))
        .build_dyn();
    let config = Config::builder().build();
    Client::with_config(raw_client, config)
}

// Returns a client that only talks through https and http2 connections.
// It is useful in testing whether our server can talk to http2.
#[allow(dead_code)]
pub fn client_http2_only() -> Client<
    aws_smithy_client::erase::DynConnector,
    aws_smithy_client::erase::DynMiddleware<aws_smithy_client::erase::DynConnector>,
> {
    // Create custom cert store and add our test certificate to prevent unknown cert issues.
    let mut reader = BufReader::new(File::open(TEST_CERT).expect("could not open certificate"));
    let certs = rustls_pemfile::certs(&mut reader).expect("could not parse certificate");
    let mut roots = tokio_rustls::rustls::RootCertStore::empty();
    roots.add_parsable_certificates(&certs);

    let connector = hyper_rustls::HttpsConnectorBuilder::new()
        .with_tls_config(
            tokio_rustls::rustls::ClientConfig::builder()
                .with_safe_defaults()
                .with_root_certificates(roots)
                .with_no_client_auth(),
        )
        .https_only()
        .enable_http2()
        .build();

    let base_url = PokemonServiceVariant::Https.base_url();
    let raw_client = Builder::new()
        .connector(DynConnector::new(Adapter::builder().build(connector)))
        .middleware_fn(rewrite_base_url(base_url))
        .build_dyn();
    let config = Config::builder().build();
    Client::with_config(raw_client, config)
}

fn rewrite_base_url(base_url: String) -> impl Fn(Request) -> Request + Clone {
    move |mut req| {
        let http_req = req.http_mut();
        let uri = format!("{base_url}{}", http_req.uri().path());
        *http_req.uri_mut() = uri.parse().unwrap();
        req
    }
}
+12 −13
Original line number Diff line number Diff line
@@ -7,9 +7,7 @@
// These tests only have access to your crate's public API.
// See: https://doc.rust-lang.org/book/ch11-03-test-organization.html#integration-tests

use std::time::Duration;

use crate::helpers::{client, PokemonService};
use crate::helpers::{client, client_http2_only, PokemonService};

use async_stream::stream;
use pokemon_service_client::{
@@ -22,7 +20,6 @@ use pokemon_service_client::{
};
use rand::Rng;
use serial_test::serial;
use tokio::time;

mod helpers;

@@ -48,19 +45,23 @@ fn get_pokemon_to_capture() -> String {
#[tokio::test]
#[serial]
async fn test_health_check_operation() {
    let _program = PokemonService::run();
    // Give PokémonService some time to start up.
    time::sleep(Duration::from_millis(500)).await;
    let _program = PokemonService::run().await;

    let _health_check = client().health_check_operation().send().await.unwrap();
}

#[tokio::test]
#[serial]
async fn test_health_check_operation_http2() {
    // Make sure our server can serve http2
    let _program = PokemonService::run_https().await;
    let _health_check = client_http2_only().health_check_operation().send().await.unwrap();
}

#[tokio::test]
#[serial]
async fn simple_integration_test() {
    let _program = PokemonService::run();
    // Give PokémonService some time to start up.
    time::sleep(Duration::from_millis(500)).await;
    let _program = PokemonService::run().await;

    let service_statistics_out = client().get_server_statistics().send().await.unwrap();
    assert_eq!(0, service_statistics_out.calls_count.unwrap());
@@ -112,9 +113,7 @@ async fn simple_integration_test() {
#[tokio::test]
#[serial]
async fn event_stream_test() {
    let _program = PokemonService::run();
    // Give PokémonService some time to start up.
    time::sleep(Duration::from_millis(500)).await;
    let _program = PokemonService::run().await;

    let mut team = vec![];
    let input_stream = stream! {
Loading