Unverified Commit d89a90d6 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Add request ID to response headers (#2438)



* Add request ID to response headers

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Add parsing test

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Style

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* CHANGELOG

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Fix import

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Panic if ServerRequestIdProviderLayer is not present

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Own value

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Correct docs

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Add order of layer to expect() message

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Remove Box

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Require order of request ID layers

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Revert "Require order of request ID layers"

This reverts commit 147eef26d5f8972b810dc5994e57865a872e44b0.

* One layer to generate and inject the header

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* HeaderName for header name

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* CHANGELOG

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Remove additional layer

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Remove to_owned

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Add tests, remove unnecessary clone

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* take() ResponsePackage instead

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Update docs

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Update docs

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* cargo fmt

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

* Update CHANGELOG

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

---------

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
Co-authored-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent abbf78fd
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [
references = ["smithy-rs#2474"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"

[[smithy-rs]]
message = """Servers can send the `ServerRequestId` in the response headers.
Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`:
```
let app = app
    .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")));
```
"""
references = ["smithy-rs#2438"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server"}
author = "82marbag"
+133 −9
Original line number Diff line number Diff line
@@ -12,8 +12,10 @@
//! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request.
//! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and
//! data related to a single operation.
//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
//!
//! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service.
//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
//!
//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
//!
@@ -34,7 +36,8 @@
//!     .operation(handler)
//!     .build().unwrap();
//!
//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */
//! let app = app
//!     .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */
//!
//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
//!     .parse()
@@ -42,12 +45,15 @@
//! let server = hyper::Server::bind(&bind).serve(app.into_make_service());
//! ```

use std::future::Future;
use std::{
    fmt::Display,
    task::{Context, Poll},
};

use futures_util::TryFuture;
use http::request::Parts;
use http::{header::HeaderName, HeaderValue, Response};
use thiserror::Error;
use tower::{Layer, Service};
use uuid::Uuid;
@@ -74,6 +80,10 @@ impl ServerRequestId {
    pub fn new() -> Self {
        Self { id: Uuid::new_v4() }
    }

    pub(crate) fn to_header(&self) -> HeaderValue {
        HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
    }
}

impl Display for ServerRequestId {
@@ -99,17 +109,28 @@ impl Default for ServerRequestId {
#[derive(Clone)]
pub struct ServerRequestIdProvider<S> {
    inner: S,
    header_key: Option<HeaderName>,
}

/// A layer that provides services with a unique request ID instance
#[derive(Debug)]
#[non_exhaustive]
pub struct ServerRequestIdProviderLayer;
pub struct ServerRequestIdProviderLayer {
    header_key: Option<HeaderName>,
}

impl ServerRequestIdProviderLayer {
    /// Generate a new unique request ID
    /// Generate a new unique request ID and do not add it as a response header
    /// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header
    pub fn new() -> Self {
        Self {}
        Self { header_key: None }
    }

    /// Generate a new unique request ID and add it as a response header
    pub fn new_with_response_header(header_key: HeaderName) -> Self {
        Self {
            header_key: Some(header_key),
        }
    }
}

@@ -123,25 +144,47 @@ impl<S> Layer<S> for ServerRequestIdProviderLayer {
    type Service = ServerRequestIdProvider<S>;

    fn layer(&self, inner: S) -> Self::Service {
        ServerRequestIdProvider { inner }
        ServerRequestIdProvider {
            inner,
            header_key: self.header_key.clone(),
        }
    }
}

impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
where
    S: Service<http::Request<Body>>,
    S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
    S::Future: std::marker::Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;
    type Future = ServerRequestIdResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
        req.extensions_mut().insert(ServerRequestId::new());
        self.inner.call(req)
        let request_id = ServerRequestId::new();
        match &self.header_key {
            Some(header_key) => {
                req.extensions_mut().insert(request_id.clone());
                ServerRequestIdResponseFuture {
                    response_package: Some(ResponsePackage {
                        request_id,
                        header_key: header_key.clone(),
                    }),
                    fut: self.inner.call(req),
                }
            }
            None => {
                req.extensions_mut().insert(request_id);
                ServerRequestIdResponseFuture {
                    response_package: None,
                    fut: self.inner.call(req),
                }
            }
        }
    }
}

@@ -150,3 +193,84 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
        internal_server_error()
    }
}

struct ResponsePackage {
    request_id: ServerRequestId,
    header_key: HeaderName,
}

pin_project_lite::pin_project! {
    pub struct ServerRequestIdResponseFuture<Fut> {
        response_package: Option<ResponsePackage>,
        #[pin]
        fut: Fut,
    }
}

impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
where
    Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
{
    type Output = Result<Fut::Ok, Fut::Error>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let fut = this.fut;
        let response_package = this.response_package;
        fut.try_poll(cx).map_ok(|mut res| {
            if let Some(response_package) = response_package.take() {
                res.headers_mut()
                    .insert(response_package.header_key, response_package.request_id.to_header());
            }
            res
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::{Body, BoxBody};
    use crate::request::Request;
    use http::HeaderValue;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    #[test]
    fn test_request_id_parsed_by_header_value_infallible() {
        ServerRequestId::new().to_header();
    }

    #[tokio::test]
    async fn test_request_id_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new_with_response_header(
                HeaderName::from_static("x-request-id"),
            ))
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();
        let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

        assert!(HeaderValue::from_str(request_id).is_ok());
    }

    #[tokio::test]
    async fn test_request_id_not_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new())
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert!(res.headers().is_empty());
    }
}