Unverified Commit 1bb0349e authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Update server TLS example to showcase how to extract connection information (#3023)

Showcase how to pass TLS derived information to server operation
handlers, like the certificate chain presented by the client.

To test, run `RUST_LOG=pokemon_service_tls=DEBUG cargo run` and `curl -v
--insecure https://localhost:13734/do-nothing`. Alternatively, run the
tests using `RUST_LOG=pokemon_service_tls=DEBUG cargo test`, which have
been updated to hit the operation that logs the TLS connection derived
information.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 42cc7cf5
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -202,9 +202,10 @@ jobs:
          popd &>/dev/null
        done

  # We make sure that Smithy-rs can be compiled on platforms that are not natively supported by GitHub actions.
  # We run as many tests we can on those platforms because they require a more complicated setup involving architecture
  # emulation.
  # We make sure that smithy-rs can be compiled on platforms that are not
  # natively supported by GitHub actions. We run as many tests as we can on
  # those platforms, but not all of them, because they require a more
  # complicated setup involving architecture emulation.
  test-exotic-platform-support:
    name: Exotic platform support
    runs-on: ubuntu-latest
@@ -212,7 +213,8 @@ jobs:
      fail-fast: false
      matrix:
        include:
        # We always exclude aws-smithy-http-server-python since the Python framework is experimental.
        # We always exclude `aws-smithy-http-server-python` and
        # `aws-smithy-http-server-typescript` since they are experimental.
        - target: i686-unknown-linux-gnu
          build_smithy_rs_features: --all-features
          build_aws_exclude: ''
+1 −0
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ description = "A smithy Rust service to retrieve information about Pokémon."
clap = { version = "4.1.11", features = ["derive"] }
hyper = { version = "0.14.26", features = ["server"] }
tokio = "1.26.0"
tracing = "0.1"

# These dependencies are only required for the `pokemon-service-tls` program.
tls-listener = { version = "0.7.0", features = ["rustls", "hyper-h2"] }
+52 −7
Original line number Diff line number Diff line
@@ -24,7 +24,9 @@

use std::{fs::File, future, io::BufReader, net::SocketAddr, sync::Arc};

use aws_smithy_http_server::AddExtensionLayer;
use aws_smithy_http_server::{
    request::connect_info::ConnectInfo, routing::Connected, AddExtensionLayer,
};
use clap::Parser;
use futures_util::stream::StreamExt;
use tokio_rustls::{
@@ -33,10 +35,10 @@ use tokio_rustls::{
};

use pokemon_service_common::{
    capture_pokemon, check_health, do_nothing, get_pokemon_species, get_server_statistics,
    get_storage, setup_tracing, stream_pokemon_radio, State,
    capture_pokemon, check_health, get_pokemon_species, get_server_statistics, get_storage,
    setup_tracing, stream_pokemon_radio, State,
};
use pokemon_service_server_sdk::{PokemonService, PokemonServiceConfig};
use pokemon_service_server_sdk::{input, output, PokemonService, PokemonServiceConfig};
use pokemon_service_tls::{DEFAULT_ADDRESS, DEFAULT_PORT, DEFAULT_TEST_CERT, DEFAULT_TEST_KEY};

#[derive(Parser, Debug)]
@@ -56,6 +58,46 @@ struct Args {
    tls_key_path: String,
}

/// Information derived from the TLS connection.
#[derive(Debug, Clone)]
pub struct TlsConnectInfo {
    /// The remote peer address of this connection.
    pub socket_addr: SocketAddr,

    /// The set of TLS certificates presented by the peer in this connection.
    pub certs: Option<Arc<Vec<Certificate>>>,
}

impl Connected<&tokio_rustls::server::TlsStream<hyper::server::conn::AddrStream>>
    for TlsConnectInfo
{
    fn connect_info(
        target: &tokio_rustls::server::TlsStream<hyper::server::conn::AddrStream>,
    ) -> Self {
        let (addr_stream, session) = target.get_ref();
        let socket_addr = addr_stream.remote_addr();

        let certs = session
            .peer_certificates()
            .map(|certs| Arc::new(certs.to_vec()));

        TlsConnectInfo { socket_addr, certs }
    }
}

/// Empty operation used to showcase how we can get access to information derived from the TLS
/// connection in.
pub async fn do_nothing_with_tls_connect_info(
    _input: input::DoNothingInput,
    ConnectInfo(tls_connect_info): ConnectInfo<TlsConnectInfo>,
) -> output::DoNothingOutput {
    // Logging these might pose a security concern! You probably don't want to do this in
    // production.
    tracing::debug!(?tls_connect_info.certs, "peer TLS certificates");

    output::DoNothingOutput {}
}

#[tokio::main]
pub async fn main() {
    let args = Args::parse();
@@ -73,7 +115,7 @@ pub async fn main() {
        .get_storage(get_storage)
        .get_server_statistics(get_server_statistics)
        .capture_pokemon(capture_pokemon)
        .do_nothing(do_nothing)
        .do_nothing(do_nothing_with_tls_connect_info)
        .check_health(check_health)
        .stream_pokemon_radio(stream_pokemon_radio)
        .build()
@@ -96,8 +138,11 @@ pub async fn main() {
            future::ready(true)
        }
    });
    let server = hyper::Server::builder(hyper::server::accept::from_stream(listener))
        .serve(app.into_make_service());
    // Using `into_make_service_with_connect_info`, rather than `into_make_service`, to adjoin the `TlsConnectInfo`
    // connection info.
    let make_app = app.into_make_service_with_connect_info::<TlsConnectInfo>();
    let server =
        hyper::Server::builder(hyper::server::accept::from_stream(listener)).serve(make_app);
    if let Err(err) = server.await {
        eprintln!("server error: {}", err);
    }
+7 −7
Original line number Diff line number Diff line
@@ -7,23 +7,23 @@ pub mod common;

use serial_test::serial;

#[tokio::test]
#[serial]
// This test invokes an operation with a client that can only send HTTP2 requests and whose TLS
// implementation is backed by `rustls`.
async fn test_check_health_http2_rustls_connector() {
#[tokio::test]
#[serial]
async fn test_do_nothing_http2_rustls_connector() {
    let _child = common::run_server().await;
    let client = common::client_http2_only();

    let _check_health = client.check_health().send().await.unwrap();
    let _check_health = client.do_nothing().send().await.unwrap();
}

// This test invokes an operation with a client whose TLS implementation is backed by `native_tls`.
#[tokio::test]
#[serial]
// This test invokes an operation with a client whose TLS implementation is backed by `native_tls`.
async fn test_check_health_native_tls_connector() {
async fn test_do_nothing_native_tls_connector() {
    let _child = common::run_server().await;
    let client = common::native_tls_client();

    let _check_health = client.check_health().send().await.unwrap();
    let _check_health = client.do_nothing().send().await.unwrap();
}
+1 −0
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ regex = "1.5.5"
serde_urlencoded = "0.7"
thiserror = "1.0.40"
tokio = { version = "1.23.1", features = ["full"] }
tokio-rustls = "0.24"
tower = { version = "0.4.11", features = ["util", "make"], default-features = false }
tower-http = { version = "0.3", features = ["add-extension", "map-response-body"] }
tracing = "0.1.35"