Unverified Commit 9057bd1f authored by Harry Barber's avatar Harry Barber Committed by GitHub
Browse files

Improve extractor errors (#2041)

parent 6ca5ed21
Loading
Loading
Loading
Loading
+22 −6
Original line number Diff line number Diff line
@@ -12,14 +12,31 @@
//! and [`ConnectInfo`] with a service builder.

use http::request::Parts;
use thiserror::Error;

use crate::Extension;
use crate::{body::BoxBody, response::IntoResponse};

use super::FromParts;
use super::{internal_server_error, FromParts};

/// The [`ConnectInfo`] was not found in the [`http::Request`] extensions.
///
/// Use [`IntoMakeServiceWithConnectInfo`](crate::routing::IntoMakeServiceWithConnectInfo) to ensure it's present.
#[non_exhaustive]
#[derive(Debug, Error)]
#[error(
    "`ConnectInfo` is not present in the `http::Request` extensions - consider using `aws_smithy_http_server::routing::IntoMakeServiceWithConnectInfo`"
)]
pub struct MissingConnectInfo;

impl<Protocol> IntoResponse<Protocol> for MissingConnectInfo {
    fn into_response(self) -> http::Response<BoxBody> {
        internal_server_error()
    }
}

/// Extractor for getting connection information produced by a `Connected`.
///
/// Note this extractor requires the existence of [`Extension<ConnectInfo<T>>`] in the [`http::Extensions`]. This is
/// Note this extractor requires the existence of [`ConnectInfo<T>`] in the [`http::Extensions`]. This is
/// automatically inserted by the [`IntoMakeServiceWithConnectInfo`](crate::routing::IntoMakeServiceWithConnectInfo)
/// middleware, which can be applied using the `into_make_service_with_connect_info` method on your generated service.
#[derive(Clone, Debug)]
@@ -29,10 +46,9 @@ impl<P, T> FromParts<P> for ConnectInfo<T>
where
    T: Send + Sync + 'static,
{
    type Rejection = <Extension<Self> as FromParts<P>>::Rejection;
    type Rejection = MissingConnectInfo;

    fn from_parts(parts: &mut Parts) -> Result<Self, Self::Rejection> {
        let Extension(connect_info) = <Extension<Self> as FromParts<P>>::from_parts(parts)?;
        Ok(connect_info)
        parts.extensions.remove().ok_or(MissingConnectInfo)
    }
}
+5 −9
Original line number Diff line number Diff line
@@ -50,14 +50,11 @@

use std::ops::Deref;

use http::StatusCode;
use thiserror::Error;

use crate::{
    body::{empty, BoxBody},
    request::FromParts,
    response::IntoResponse,
};
use crate::{body::BoxBody, request::FromParts, response::IntoResponse};

use super::internal_server_error;

/// Generic extension type stored in and extracted from [request extensions].
///
@@ -79,15 +76,14 @@ impl<T> Deref for Extension<T> {
}

/// The extension has not been added to the [`Request`](http::Request) or has been previously removed.
#[non_exhaustive]
#[derive(Debug, Error)]
#[error("the `Extension` is not present in the `http::Request`")]
pub struct MissingExtension;

impl<Protocol> IntoResponse<Protocol> for MissingExtension {
    fn into_response(self) -> http::Response<BoxBody> {
        let mut response = http::Response::new(empty());
        *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
        response
        internal_server_error()
    }
}

+79 −11
Original line number Diff line number Diff line
@@ -12,41 +12,109 @@ pub use lambda_http::{
    aws_lambda_events::apigw::{ApiGatewayProxyRequestContext, ApiGatewayV2httpRequestContext},
    Context,
};
use thiserror::Error;

use super::{extension::MissingExtension, FromParts};
use crate::Extension;
use super::{internal_server_error, FromParts};
use crate::{body::BoxBody, response::IntoResponse};

/// The [`Context`] was not found in the [`http::Request`] extensions.
///
/// Use [`LambdaHandler`](crate::routing::LambdaHandler) to ensure it's present.
#[non_exhaustive]
#[derive(Debug, Error)]
#[error("`Context` is not present in the `http::Request` extensions - consider using `aws_smithy_http_server::routing::LambdaHandler`")]
pub struct MissingContext;

impl<Protocol> IntoResponse<Protocol> for MissingContext {
    fn into_response(self) -> http::Response<BoxBody> {
        internal_server_error()
    }
}

impl<P> FromParts<P> for Context {
    type Rejection = MissingExtension;
    type Rejection = MissingContext;

    fn from_parts(parts: &mut http::request::Parts) -> Result<Self, Self::Rejection> {
        let Extension(context) = <Extension<Self> as FromParts<P>>::from_parts(parts)?;
        Ok(context)
        parts.extensions.remove().ok_or(MissingContext)
    }
}

#[derive(Debug, Error)]
enum MissingGatewayContextTypeV1 {
    #[error("`RequestContext` is not present in the `http::Request` extensions - consider using `aws_smithy_http_server::routing::LambdaHandler`")]
    MissingRequestContext,
    #[error("`RequestContext::ApiGatewayV2` is present in the `http::Request` extensions - consider using the `aws_smithy_http_server::request::lambda::ApiGatewayV2httpRequestContext` extractor")]
    VersionMismatch,
}

/// The [`RequestContext::ApiGatewayV1`] was not found in the [`http::Request`] extensions.
///
/// Use [`LambdaHandler`](crate::routing::LambdaHandler) to ensure it's present and ensure that you're using "ApiGatewayV1".
#[derive(Debug, Error)]
#[error("{inner}")]
pub struct MissingGatewayContextV1 {
    inner: MissingGatewayContextTypeV1,
}

impl<Protocol> IntoResponse<Protocol> for MissingGatewayContextV1 {
    fn into_response(self) -> http::Response<BoxBody> {
        internal_server_error()
    }
}

impl<P> FromParts<P> for ApiGatewayProxyRequestContext {
    type Rejection = MissingExtension;
    type Rejection = MissingGatewayContextV1;

    fn from_parts(parts: &mut http::request::Parts) -> Result<Self, Self::Rejection> {
        let Extension(context) = <Extension<RequestContext> as FromParts<P>>::from_parts(parts)?;
        let context = parts.extensions.remove().ok_or(MissingGatewayContextV1 {
            inner: MissingGatewayContextTypeV1::MissingRequestContext,
        })?;
        if let RequestContext::ApiGatewayV1(context) = context {
            Ok(context)
        } else {
            Err(MissingExtension)
            Err(MissingGatewayContextV1 {
                inner: MissingGatewayContextTypeV1::VersionMismatch,
            })
        }
    }
}

#[derive(Debug, Error)]
enum MissingGatewayContextTypeV2 {
    #[error("`RequestContext` is not present in the `http::Request` extensions - consider using `aws_smithy_http_server::routing::LambdaHandler`")]
    MissingRequestContext,
    #[error("`RequestContext::ApiGatewayV1` is present in the `http::Request` extensions - consider using the `aws_smithy_http_server::request::lambda::ApiGatewayProxyRequestContext` extractor")]
    VersionMismatch,
}

/// The [`RequestContext::ApiGatewayV2`] was not found in the [`http::Request`] extensions.
///
/// Use [`LambdaHandler`](crate::routing::LambdaHandler) to ensure it's present and ensure that you're using "ApiGatewayV2".
#[derive(Debug, Error)]
#[error("{inner}")]
pub struct MissingGatewayContextV2 {
    inner: MissingGatewayContextTypeV2,
}

impl<Protocol> IntoResponse<Protocol> for MissingGatewayContextV2 {
    fn into_response(self) -> http::Response<BoxBody> {
        internal_server_error()
    }
}

impl<P> FromParts<P> for ApiGatewayV2httpRequestContext {
    type Rejection = MissingExtension;
    type Rejection = MissingGatewayContextV2;

    fn from_parts(parts: &mut http::request::Parts) -> Result<Self, Self::Rejection> {
        let Extension(context) = <Extension<RequestContext> as FromParts<P>>::from_parts(parts)?;
        let context = parts.extensions.remove().ok_or(MissingGatewayContextV2 {
            inner: MissingGatewayContextTypeV2::MissingRequestContext,
        })?;
        if let RequestContext::ApiGatewayV2(context) = context {
            Ok(context)
        } else {
            Err(MissingExtension)
            Err(MissingGatewayContextV2 {
                inner: MissingGatewayContextTypeV2::VersionMismatch,
            })
        }
    }
}
+12 −2
Original line number Diff line number Diff line
@@ -46,9 +46,13 @@ use futures_util::{
    future::{try_join, MapErr, MapOk, TryJoin},
    TryFutureExt,
};
use http::{request::Parts, Extensions, HeaderMap, Request, Uri};
use http::{request::Parts, Extensions, HeaderMap, Request, StatusCode, Uri};

use crate::{rejection::any_rejections, response::IntoResponse};
use crate::{
    body::{empty, BoxBody},
    rejection::any_rejections,
    response::IntoResponse,
};

pub mod connect_info;
pub mod extension;
@@ -56,6 +60,12 @@ pub mod extension;
#[cfg_attr(docsrs, doc(cfg(feature = "aws-lambda")))]
pub mod lambda;

fn internal_server_error() -> http::Response<BoxBody> {
    let mut response = http::Response::new(empty());
    *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
    response
}

#[doc(hidden)]
#[derive(Debug)]
pub struct RequestParts<B> {