Unverified Commit 778539f4 authored by Harry Barber's avatar Harry Barber Committed by GitHub
Browse files

Add protocol specific routers (#1666)

* Add protocol specific routers

* Replace internals of `Router` with protocol specific routers
parent 30d985ad
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -7,6 +7,18 @@
use crate::rejection::MissingContentTypeReason;
use crate::request::RequestParts;

/// [AWS REST JSON 1.0 Protocol](https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restjson1-protocol.html).
pub struct AwsRestJson1;

/// [AWS REST XML Protocol](https://awslabs.github.io/smithy/2.0/aws/protocols/aws-restxml-protocol.html).
pub struct AwsRestXml;

/// [AWS JSON 1.0 Protocol](https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html).
pub struct AwsJson10;

/// [AWS JSON 1.1 Protocol](https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_1-protocol.html).
pub struct AwsJson11;

/// Supported protocols.
#[derive(Debug, Clone, Copy)]
pub enum Protocol {
+6 −0
Original line number Diff line number Diff line
@@ -36,3 +36,9 @@ use crate::body::BoxBody;

#[doc(hidden)]
pub type Response<T = BoxBody> = http::Response<T>;

/// A protocol aware function taking `self` to [`http::Response`].
pub trait IntoResponse<Protocol> {
    /// Performs a conversion into a [`http::Response`].
    fn into_response(self) -> http::Response<BoxBody>;
}
+2 −20
Original line number Diff line number Diff line
@@ -33,29 +33,11 @@
 */

//! Future types.
use crate::body::BoxBody;
use futures_util::future::Either;
use http::{Request, Response};
use std::{convert::Infallible, future::ready};
use tower::util::Oneshot;

use super::Route;
pub use super::{into_make_service::IntoMakeService, route::RouteFuture};

type OneshotRoute<B> = Oneshot<super::Route<B>, Request<B>>;
type ReadyResponse = std::future::Ready<Result<Response<BoxBody>, Infallible>>;

opaque_future! {
    /// Response future for [`Router`](super::Router).
    pub type RouterFuture<B> =
        futures_util::future::Either<OneshotRoute<B>, ReadyResponse>;
}

impl<B> RouterFuture<B> {
    pub(super) fn from_oneshot(future: Oneshot<super::Route<B>, Request<B>>) -> Self {
        Self::new(Either::Left(future))
    }

    pub(super) fn from_response(response: Response<BoxBody>) -> Self {
        Self::new(Either::Right(ready(Ok(response))))
    }
    pub type RouterFuture<B> = super::routers::RoutingFuture<Route<B>, B>;
}
+58 −155
Original line number Diff line number Diff line
@@ -8,18 +8,17 @@
//! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html

use self::request_spec::RequestSpec;
use self::tiny_map::TinyMap;
use self::routers::{aws_json::AwsJsonRouter, rest::RestRouter, RoutingService};
use crate::body::{boxed, Body, BoxBody, HttpBody};
use crate::error::BoxError;
use crate::protocols::Protocol;
use crate::runtime_error::{RuntimeError, RuntimeErrorKind};
use http::{Request, Response, StatusCode};
use crate::protocols::{AwsJson10, AwsJson11, AwsRestJson1, AwsRestXml};

use http::{Request, Response};
use std::{
    convert::Infallible,
    task::{Context, Poll},
};
use tower::layer::Layer;
use tower::util::ServiceExt;
use tower::{Service, ServiceBuilder};
use tower_http::map_response_body::MapResponseBodyLayer;

@@ -31,6 +30,7 @@ mod lambda_handler;
pub mod request_spec;

mod route;
mod routers;
mod tiny_map;

pub use self::lambda_handler::LambdaHandler;
@@ -61,11 +61,6 @@ pub struct Router<B = Body> {
    routes: Routes<B>,
}

// This constant determines when the `TinyMap` implementation switches from being a `Vec` to a
// `HashMap`. This is chosen to be 15 as a result of the discussion around
// https://github.com/awslabs/smithy-rs/pull/1429#issuecomment-1147516546
const ROUTE_CUTOFF: usize = 15;

/// Protocol-aware routes types.
///
/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the
@@ -75,10 +70,10 @@ const ROUTE_CUTOFF: usize = 15;
/// directly found in the `X-Amz-Target` HTTP header.
#[derive(Debug)]
enum Routes<B = Body> {
    RestXml(Vec<(Route<B>, RequestSpec)>),
    RestJson1(Vec<(Route<B>, RequestSpec)>),
    AwsJson10(TinyMap<String, Route<B>, ROUTE_CUTOFF>),
    AwsJson11(TinyMap<String, Route<B>, ROUTE_CUTOFF>),
    RestXml(RoutingService<RestRouter<Route<B>>, AwsRestXml>),
    RestJson1(RoutingService<RestRouter<Route<B>>, AwsRestJson1>),
    AwsJson10(RoutingService<AwsJsonRouter<Route<B>>, AwsJson10>),
    AwsJson11(RoutingService<AwsJsonRouter<Route<B>>, AwsJson11>),
}

impl<B> Clone for Router<B> {
@@ -104,29 +99,6 @@ impl<B> Router<B>
where
    B: Send + 'static,
{
    /// Return the correct, protocol-specific "Not Found" response for an unknown operation.
    fn unknown_operation(&self) -> RouterFuture<B> {
        let protocol = match &self.routes {
            Routes::RestJson1(_) => Protocol::RestJson1,
            Routes::RestXml(_) => Protocol::RestXml,
            Routes::AwsJson10(_) => Protocol::AwsJson10,
            Routes::AwsJson11(_) => Protocol::AwsJson11,
        };
        let error = RuntimeError {
            protocol,
            kind: RuntimeErrorKind::UnknownOperation,
        };
        RouterFuture::from_response(error.into_response())
    }

    /// Return the HTTP error response for non allowed method.
    fn method_not_allowed(&self) -> RouterFuture<B> {
        RouterFuture::from_response({
            let mut res = Response::new(crate::body::empty());
            *res.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
            res
        })
    }
    /// Convert this router into a [`MakeService`], that is a [`Service`] whose
    /// response is another service.
    ///
@@ -155,46 +127,21 @@ where
        NewResBody::Error: Into<BoxError>,
    {
        let layer = ServiceBuilder::new()
            .layer_fn(Route::new)
            .layer(MapResponseBodyLayer::new(boxed))
            .layer(layer);
        match self.routes {
            Routes::RestJson1(routes) => {
                let routes = routes
                    .into_iter()
                    .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
                    .collect();
                Router {
                    routes: Routes::RestJson1(routes),
                }
            }
            Routes::RestXml(routes) => {
                let routes = routes
                    .into_iter()
                    .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
                    .collect();
                Router {
                    routes: Routes::RestXml(routes),
                }
            }
            Routes::AwsJson10(routes) => {
                let routes = routes
                    .into_iter()
                    .map(|(operation, route)| (operation, Layer::layer(&layer, route)))
                    .collect();
                Router {
                    routes: Routes::AwsJson10(routes),
                }
            }
            Routes::AwsJson11(routes) => {
                let routes = routes
                    .into_iter()
                    .map(|(operation, route)| (operation, Layer::layer(&layer, route)))
                    .collect();
                Router {
                    routes: Routes::AwsJson11(routes),
                }
            }
            Routes::RestJson1(routes) => Router {
                routes: Routes::RestJson1(routes.map(|router| router.layer(layer).boxed())),
            },
            Routes::RestXml(routes) => Router {
                routes: Routes::RestXml(routes.map(|router| router.layer(layer).boxed())),
            },
            Routes::AwsJson10(routes) => Router {
                routes: Routes::AwsJson10(routes.map(|router| router.layer(layer).boxed())),
            },
            Routes::AwsJson11(routes) => Router {
                routes: Routes::AwsJson11(routes.map(|router| router.layer(layer).boxed())),
            },
        }
    }

@@ -211,18 +158,14 @@ where
            ),
        >,
    {
        let mut routes: Vec<(Route<B>, RequestSpec)> = routes
        let svc = RoutingService::new(
            routes
                .into_iter()
            .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
            .collect();

        // Sort them once by specifity, with the more specific routes sorted before the less
        // specific ones, so that when routing a request we can simply iterate through the routes
        // and pick the first one that matches.
        routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));

                .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc)))
                .collect(),
        );
        Self {
            routes: Routes::RestJson1(routes),
            routes: Routes::RestJson1(svc),
        }
    }

@@ -239,18 +182,14 @@ where
            ),
        >,
    {
        let mut routes: Vec<(Route<B>, RequestSpec)> = routes
        let svc = RoutingService::new(
            routes
                .into_iter()
            .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec))
            .collect();

        // Sort them once by specifity, with the more specific routes sorted before the less
        // specific ones, so that when routing a request we can simply iterate through the routes
        // and pick the first one that matches.
        routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank()));

                .map(|(svc, request_spec)| (request_spec, Route::from_box_clone_service(svc)))
                .collect(),
        );
        Self {
            routes: Routes::RestXml(routes),
            routes: Routes::RestXml(svc),
        }
    }

@@ -267,13 +206,15 @@ where
            ),
        >,
    {
        let routes = routes
        let svc = RoutingService::new(
            routes
                .into_iter()
                .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
            .collect();
                .collect(),
        );

        Self {
            routes: Routes::AwsJson10(routes),
            routes: Routes::AwsJson10(svc),
        }
    }

@@ -290,13 +231,15 @@ where
            ),
        >,
    {
        let routes = routes
        let svc = RoutingService::new(
            routes
                .into_iter()
                .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc)))
            .collect();
                .collect(),
        );

        Self {
            routes: Routes::AwsJson11(routes),
            routes: Routes::AwsJson11(svc),
        }
    }
}
@@ -316,55 +259,15 @@ where

    #[inline]
    fn call(&mut self, req: Request<B>) -> Self::Future {
        match &self.routes {
        let fut = match &mut self.routes {
            // REST routes.
            Routes::RestJson1(routes) | Routes::RestXml(routes) => {
                let mut method_not_allowed = false;

                // Loop through all the routes and validate if any of them matches. Routes are already ranked.
                for (route, request_spec) in routes {
                    match request_spec.matches(&req) {
                        request_spec::Match::Yes => {
                            return RouterFuture::from_oneshot(route.clone().oneshot(req));
                        }
                        request_spec::Match::MethodNotAllowed => method_not_allowed = true,
                        // Continue looping to see if another route matches.
                        request_spec::Match::No => continue,
                    }
                }

                if method_not_allowed {
                    // The HTTP method is not correct.
                    self.method_not_allowed()
                } else {
                    // In any other case return the `RuntimeError::UnknownOperation`.
                    self.unknown_operation()
                }
            }
            Routes::RestJson1(routes) => routes.call(req),
            Routes::RestXml(routes) => routes.call(req),
            // AwsJson routes.
            Routes::AwsJson10(routes) | Routes::AwsJson11(routes) => {
                if req.uri() == "/" {
                    // Check the request method for POST.
                    if req.method() == http::Method::POST {
                        // Find the `x-amz-target` header.
                        if let Some(target) = req.headers().get("x-amz-target") {
                            if let Ok(target) = target.to_str() {
                                // Lookup in the `TinyMap` for a route for the target.
                                let route = routes.get(target);
                                if let Some(route) = route {
                                    return RouterFuture::from_oneshot(route.clone().oneshot(req));
                                }
                            }
                        }
                    } else {
                        // The HTTP method is not POST.
                        return self.method_not_allowed();
                    }
                }
                // In any other case return the `RuntimeError::UnknownOperation`.
                self.unknown_operation()
            }
        }
            Routes::AwsJson10(routes) => routes.call(req),
            Routes::AwsJson11(routes) => routes.call(req),
        };
        RouterFuture::new(fut)
    }
}

@@ -376,7 +279,7 @@ mod rest_tests {
        routing::request_spec::*,
    };
    use futures_util::Future;
    use http::{HeaderMap, Method};
    use http::{HeaderMap, Method, StatusCode};
    use std::pin::Pin;

    /// Helper function to build a `Request`. Used in other test modules.
@@ -601,7 +504,7 @@ mod awsjson_tests {
    use super::*;
    use crate::body::boxed;
    use futures_util::Future;
    use http::{HeaderMap, HeaderValue, Method};
    use http::{HeaderMap, HeaderValue, Method, StatusCode};
    use pretty_assertions::assert_eq;
    use std::pin::Pin;

+154 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::convert::Infallible;

use http::header::ToStrError;
use thiserror::Error;
use tower::{Layer, Service};

use crate::{
    body::{empty, BoxBody},
    extension::RuntimeErrorExtension,
    protocols::{AwsJson10, AwsJson11},
    response::IntoResponse,
    routing::{tiny_map::TinyMap, Route},
};

use super::Router;

/// An AWS JSON routing error.
#[derive(Debug, Error)]
pub enum Error {
    /// Relative URI was not "/".
    #[error("relative URI is not \"/\"")]
    NotRootUrl,
    /// Method was not `POST`.
    #[error("method not POST")]
    MethodNotAllowed,
    /// Missing the `x-amz-target` header.
    #[error("missing the \"x-amz-target\" header")]
    MissingHeader,
    /// Unable to parse header into UTF-8.
    #[error("failed to parse header: {0}")]
    InvalidHeader(ToStrError),
    /// Operation not found.
    #[error("operation not found")]
    NotFound,
}

impl IntoResponse<AwsJson10> for Error {
    fn into_response(self) -> http::Response<BoxBody> {
        match self {
            Error::MethodNotAllowed => super::method_disallowed(),
            _ => http::Response::builder()
                .status(http::StatusCode::NOT_FOUND)
                .header(http::header::CONTENT_TYPE, "application/x-amz-json-1.0")
                .extension(RuntimeErrorExtension::new(
                    super::UNKNOWN_OPERATION_EXCEPTION.to_string(),
                ))
                .body(empty())
                .expect("invalid HTTP response for AWS JSON routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"),
        }
    }
}

impl IntoResponse<AwsJson11> for Error {
    fn into_response(self) -> http::Response<BoxBody> {
        match self {
            Error::MethodNotAllowed => super::method_disallowed(),
            _ => http::Response::builder()
                .status(http::StatusCode::NOT_FOUND)
                .header(http::header::CONTENT_TYPE, "application/x-amz-json-1.1")
                .extension(RuntimeErrorExtension::new(
                    super::UNKNOWN_OPERATION_EXCEPTION.to_string(),
                ))
                .body(empty())
                .expect("invalid HTTP response for AWS JSON routing error; please file a bug report under https://github.com/awslabs/smithy-rs/issues"),
        }
    }
}

// This constant determines when the `TinyMap` implementation switches from being a `Vec` to a
// `HashMap`. This is chosen to be 15 as a result of the discussion around
// https://github.com/awslabs/smithy-rs/pull/1429#issuecomment-1147516546
const ROUTE_CUTOFF: usize = 15;

/// A [`Router`] supporting [`AWS JSON 1.0`] and [`AWS JSON 1.1`] protocols.
///
/// [AWS JSON 1.0]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_0-protocol.html
/// [AWS JSON 1.1]: https://awslabs.github.io/smithy/2.0/aws/protocols/aws-json-1_1-protocol.html
#[derive(Debug, Clone)]
pub struct AwsJsonRouter<S> {
    routes: TinyMap<String, S, ROUTE_CUTOFF>,
}

impl<S> AwsJsonRouter<S> {
    /// Applies a [`Layer`] uniformly to all routes.
    pub fn layer<L>(self, layer: L) -> AwsJsonRouter<L::Service>
    where
        L: Layer<S>,
    {
        AwsJsonRouter {
            routes: self
                .routes
                .into_iter()
                .map(|(key, route)| (key, layer.layer(route)))
                .collect(),
        }
    }

    /// Applies type erasure to the inner route using [`Route::new`].
    pub fn boxed<B>(self) -> AwsJsonRouter<Route<B>>
    where
        S: Service<http::Request<B>, Response = http::Response<BoxBody>, Error = Infallible>,
        S: Send + Clone + 'static,
        S::Future: Send + 'static,
    {
        AwsJsonRouter {
            routes: self.routes.into_iter().map(|(key, s)| (key, Route::new(s))).collect(),
        }
    }
}

impl<B, S> Router<B> for AwsJsonRouter<S>
where
    S: Clone,
{
    type Service = S;
    type Error = Error;

    fn match_route(&self, request: &http::Request<B>) -> Result<S, Self::Error> {
        // The URI must be root,
        if request.uri() != "/" {
            return Err(Error::NotRootUrl);
        }

        // Only `Method::POST` is allowed.
        if request.method() != http::Method::POST {
            return Err(Error::MethodNotAllowed);
        }

        // Find the `x-amz-target` header.
        let target = request.headers().get("x-amz-target").ok_or(Error::MissingHeader)?;
        let target = target.to_str().map_err(Error::InvalidHeader)?;

        // Lookup in the `TinyMap` for a route for the target.
        let route = self.routes.get(target).ok_or(Error::NotFound)?;
        Ok(route.clone())
    }
}

impl<S> FromIterator<(String, S)> for AwsJsonRouter<S> {
    #[inline]
    fn from_iter<T: IntoIterator<Item = (String, S)>>(iter: T) -> Self {
        Self {
            routes: iter
                .into_iter()
                .map(|(svc, request_spec)| (svc, request_spec))
                .collect(),
        }
    }
}
Loading