Unverified Commit 8293b677 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Allow to wrap routes with Tower Layers (#869)

parent 2c1fe88d
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -32,7 +32,8 @@ regex = "1.0"
serde_urlencoded = "0.7"
thiserror = "1"
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4" }
tower = "0.4"
tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] }

[dev-dependencies]
pretty_assertions = "1"
+32 −7
Original line number Diff line number Diff line
@@ -16,13 +16,16 @@
//! [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait

use self::{future::RouterFuture, request_spec::RequestSpec};
use crate::body::{Body, BoxBody};
use crate::body::{box_body, Body, BoxBody, HttpBody};
use crate::BoxError;
use http::{Request, Response, StatusCode};
use std::{
    convert::Infallible,
    task::{Context, Poll},
};
use tower::{Service, ServiceExt};
use tower::layer::Layer;
use tower::{Service, ServiceBuilder, ServiceExt};
use tower_http::map_response_body::MapResponseBodyLayer;

pub mod future;
mod into_make_service;
@@ -34,7 +37,7 @@ pub use self::{into_make_service::IntoMakeService, route::Route};

#[derive(Debug)]
pub struct Router<B = Body> {
    routes: Vec<Route<B>>,
    routes: Vec<(Route<B>, RequestSpec)>,
}

impl<B> Clone for Router<B> {
@@ -70,7 +73,7 @@ where
        T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone + Send + 'static,
        T::Future: Send + 'static,
    {
        self.routes.push(Route::new(svc, request_spec));
        self.routes.push((Route::new(svc), request_spec));
        self
    }

@@ -84,6 +87,28 @@ where
    pub fn into_make_service(self) -> IntoMakeService<Self> {
        IntoMakeService::new(self)
    }

    /// Apply a [`tower::Layer`] to the router.
    ///
    /// All requests to the router will be processed by the layer's
    /// corresponding middleware.
    ///
    /// This can be used to add additional processing to a request for a group
    /// of routes.
    pub fn layer<L, NewReqBody, NewResBody>(self, layer: L) -> Router<NewReqBody>
    where
        L: Layer<Route<B>>,
        L::Service:
            Service<Request<NewReqBody>, Response = Response<NewResBody>, Error = Infallible> + Clone + Send + 'static,
        <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
        NewResBody: HttpBody<Data = bytes::Bytes> + Send + 'static,
        NewResBody::Error: Into<BoxError>,
    {
        let layer = ServiceBuilder::new().layer_fn(Route::new).layer(MapResponseBodyLayer::new(box_body)).layer(layer);
        let routes =
            self.routes.into_iter().map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)).collect();
        Router { routes }
    }
}

impl<B> Service<Request<B>> for Router<B>
@@ -103,8 +128,8 @@ where
    fn call(&mut self, req: Request<B>) -> Self::Future {
        let mut method_not_allowed = false;

        for route in &self.routes {
            match route.matches(&req) {
        for (route, request_spec) in &self.routes {
            match request_spec.matches(&req) {
                request_spec::Match::Yes => {
                    return RouterFuture::from_oneshot(route.clone().oneshot(req));
                }
@@ -148,7 +173,7 @@ mod tests {

        #[inline]
        fn call(&mut self, req: Request<B>) -> Self::Future {
            let body = box_body(Body::from(format!("{} :: {}", self.0, String::from(req.uri().to_string()))));
            let body = box_body(Body::from(format!("{} :: {}", self.0, req.uri().to_string())));
            let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) };
            Box::pin(fut)
        }
+3 −10
Original line number Diff line number Diff line
@@ -48,31 +48,24 @@ use std::{
use tower::Service;
use tower::{util::Oneshot, ServiceExt};

use super::request_spec::{Match, RequestSpec};

/// How routes are stored inside a [`Router`](super::Router).
pub struct Route<B = Body> {
    service: CloneBoxService<Request<B>, Response<BoxBody>, Infallible>,
    request_spec: RequestSpec,
}

impl<B> Route<B> {
    pub(super) fn new<T>(svc: T, request_spec: RequestSpec) -> Self
    pub(super) fn new<T>(svc: T) -> Self
    where
        T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone + Send + 'static,
        T::Future: Send + 'static,
    {
        Self { service: CloneBoxService::new(svc), request_spec }
    }

    pub(super) fn matches(&self, req: &Request<B>) -> Match {
        self.request_spec.matches(req)
        Self { service: CloneBoxService::new(svc) }
    }
}

impl<ReqBody> Clone for Route<ReqBody> {
    fn clone(&self) -> Self {
        Self { service: self.service.clone(), request_spec: self.request_spec.clone() }
        Self { service: self.service.clone() }
    }
}