Unverified Commit 430f4bf8 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Make `Content-Type` check in `ServerHttpBoundProtocolGenerator` protocol-agnostic (#1605)

The `Content-Type` check of the incoming request in
`ServerHttpBoundProtocolGenerator` is performed by comparing against the
configured protocol directly within the class.

Delegating to the protocol enables the implementation of protocols in
decorators loaded from the classpath in other projects outside
smithy-rs.

This commit makes the `ServerHttpBoundProtocolGenerator`
protocol-agnostic by looking up the expected incoming request
`Content-Type` header from the protocol implementation directly. To this
end, the macro-generated `check_<protocol>_content_type` functions from
the `aws-smithy-http-server` runtime have been replaced with a single
general `check_content_type` function that takes in the expected mime
type.
parent 738d84cc
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -19,7 +19,9 @@ object ServerCargoDependency {
    val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
    val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1"))
    val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
    val Mime: CargoDependency = CargoDependency("mime", CratesIo("0.3"))
    val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
    val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13"))
    val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
    val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
    val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.0"), scope = DependencyScope.Dev)
+6 −22
Original line number Diff line number Diff line
@@ -127,7 +127,9 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        "header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header"),
        "Hyper" to CargoDependency.Hyper.asType(),
        "LazyStatic" to CargoDependency.LazyStatic.asType(),
        "Mime" to ServerCargoDependency.Mime.asType(),
        "Nom" to ServerCargoDependency.Nom.asType(),
        "OnceCell" to ServerCargoDependency.OnceCell.asType(),
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
@@ -668,13 +670,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
        val parser = structuredDataParser.serverInputParser(operationShape)
        if (parser != null) {
            val contentTypeCheck = getContentTypeCheck()
            val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
            rustTemplate(
                """
                let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?;
                let bytes = #{Hyper}::body::to_bytes(body).await?;
                if !bytes.is_empty() {
                    #{SmithyHttpServer}::protocols::$contentTypeCheck(request)?;
                    static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> =
                        #{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap());
                    #{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?;
                    input = #{parser}(bytes.as_ref(), input)?;
                }
                """,
@@ -1144,26 +1148,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        return "parse_str_${containerName}_$memberName"
    }

    private fun getContentTypeCheck(): String {
        when (codegenContext.protocol) {
            RestJson1Trait.ID -> {
                return "check_rest_json_1_content_type"
            }
            RestXmlTrait.ID -> {
                return "check_rest_xml_content_type"
            }
            AwsJson1_0Trait.ID -> {
                return "check_aws_json_10_content_type"
            }
            AwsJson1_1Trait.ID -> {
                return "check_aws_json_11_content_type"
            }
            else -> {
                TODO("Protocol ${codegenContext.protocol} not supported yet")
            }
        }
    }

    /**
     * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the
     * shape targeted by the `httpPayload` trait.
+1 −1
Original line number Diff line number Diff line
@@ -25,8 +25,8 @@ http-body = "0.4"
hyper = { version = "0.14.12", features = ["server", "http1", "http2", "tcp", "stream"] }
mime = "0.3"
nom = "7"
paste = "1"
pin-project-lite = "0.2"
once_cell = "1.13"
regex = "1.5.5"
serde_urlencoded = "0.7"
strum_macros = "0.24"
+67 −163
Original line number Diff line number Diff line
@@ -4,9 +4,8 @@
 */

//! Protocol helpers.
use crate::rejection::RequestRejection;
use crate::rejection::MissingContentTypeReason;
use crate::request::RequestParts;
use paste::paste;

/// Supported protocols.
#[derive(Debug, Clone, Copy)]
@@ -17,58 +16,28 @@ pub enum Protocol {
    AwsJson11,
}

/// Implement the content-type header validation for a request.
macro_rules! impl_content_type_validation {
    ($name:literal, $type: literal, $subtype:literal, $rejection:path) => {
        paste! {
            #[doc = concat!("Validates that the request has the standard `", $type, "/", $subtype, "` content-type header.")]
            pub fn [<check_ $name _content_type>]<B>(req: &RequestParts<B>) -> Result<(), RequestRejection> {
                let mime = req
pub fn check_content_type<B>(
    req: &RequestParts<B>,
    expected_mime: &'static mime::Mime,
) -> Result<(), MissingContentTypeReason> {
    let found_mime = req
        .headers()
                    .ok_or($rejection)?
        .ok_or(MissingContentTypeReason::HeadersTakenByAnotherExtractor)?
        .get(http::header::CONTENT_TYPE)
                    .ok_or($rejection)?
        .ok_or(MissingContentTypeReason::NoContentTypeHeader)?
        .to_str()
                    .map_err(|_| $rejection)?
        .map_err(MissingContentTypeReason::ToStrError)?
        .parse::<mime::Mime>()
                    .map_err(|_| RequestRejection::MimeParse)?;
                if mime.type_() == $type && mime.subtype() == $subtype {
        .map_err(MissingContentTypeReason::MimeParseError)?;
    if &found_mime == expected_mime {
        Ok(())
    } else {
                    Err($rejection)
        Err(MissingContentTypeReason::UnexpectedMimeType {
            expected_mime,
            found_mime,
        })
    }
}
        }
    };
}

impl_content_type_validation!(
    "rest_json_1",
    "application",
    "json",
    RequestRejection::MissingRestJson1ContentType
);

impl_content_type_validation!(
    "rest_xml",
    "application",
    "xml",
    RequestRejection::MissingRestXmlContentType
);

impl_content_type_validation!(
    "aws_json_10",
    "application",
    "x-amz-json-1.0",
    RequestRejection::MissingAwsJson10ContentType
);

impl_content_type_validation!(
    "aws_json_11",
    "application",
    "x-amz-json-1.1",
    RequestRejection::MissingAwsJson11ContentType
);

#[cfg(test)]
mod tests {
@@ -83,129 +52,64 @@ mod tests {
        RequestParts::new(request)
    }

    /// This macro validates the rejection type since we cannot implement `PartialEq`
    /// for `RequestRejection` as it is based on the crate error type, which uses
    /// `crate::error::BoxError`.
    macro_rules! validate_rejection_type {
        ($result:expr, $rejection:path) => {
            match $result {
                Ok(()) => panic!("Content-type validation is expected to fail"),
                Err(e) => match e {
                    $rejection => {}
                    _ => panic!("Error {} should be {}", e.to_string(), stringify!($rejection)),
                },
            }
        };
    }
    static EXPECTED_MIME_APPLICATION_JSON: once_cell::sync::Lazy<mime::Mime> =
        once_cell::sync::Lazy::new(|| "application/json".parse::<mime::Mime>().unwrap());

    #[test]
    fn validate_rest_json_1_content_type() {
        // Check valid content-type header.
        let request = req("application/json");
        assert!(check_rest_json_1_content_type(&request).is_ok());

        // Check invalid content-type header.
        let invalid = vec![
            req("application/ajson"),
            req("application/json1"),
            req("applicatio/json"),
            req("application/xml"),
            req("text/xml"),
            req("application/x-amz-json-1.0"),
            req("application/x-amz-json-1.1"),
            RequestParts::new(Request::builder().body("").unwrap()),
        ];
        for request in &invalid {
            validate_rejection_type!(
                check_rest_json_1_content_type(request),
                RequestRejection::MissingRestJson1ContentType
            );
        }

        // Check request with not parsable content-type header.
        validate_rejection_type!(check_rest_json_1_content_type(&req("123")), RequestRejection::MimeParse);
    fn check_valid_content_type() {
        let valid_request = req("application/json");
        assert!(check_content_type(&valid_request, &EXPECTED_MIME_APPLICATION_JSON).is_ok());
    }

    #[test]
    fn validate_rest_xml_content_type() {
        // Check valid content-type header.
        let request = req("application/xml");
        assert!(check_rest_xml_content_type(&request).is_ok());

        // Check invalid content-type header.
        let invalid = vec![
            req("application/axml"),
            req("application/xml1"),
            req("applicatio/xml"),
            req("text/xml"),
            req("application/x-amz-json-1.0"),
            req("application/x-amz-json-1.1"),
            RequestParts::new(Request::builder().body("").unwrap()),
        ];
        for request in &invalid {
            validate_rejection_type!(
                check_rest_xml_content_type(request),
                RequestRejection::MissingRestXmlContentType
            );
    fn check_invalid_content_type() {
        let invalid = vec!["application/ajson", "text/xml"];
        for invalid_mime in invalid {
            let request = req(invalid_mime);
            let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);

            // Validates the rejection type since we cannot implement `PartialEq`
            // for `MissingContentTypeReason`.
            match result {
                Ok(()) => panic!("Content-type validation is expected to fail"),
                Err(e) => match e {
                    MissingContentTypeReason::UnexpectedMimeType {
                        expected_mime,
                        found_mime,
                    } => {
                        assert_eq!(expected_mime, &"application/json".parse::<mime::Mime>().unwrap());
                        assert_eq!(found_mime, invalid_mime);
                    }
                    _ => panic!("Unexpected `MissingContentTypeReason`: {}", e.to_string()),
                },
            }

        // Check request with not parsable content-type header.
        validate_rejection_type!(check_rest_xml_content_type(&req("123")), RequestRejection::MimeParse);
        }

    #[test]
    fn validate_aws_json_10_content_type() {
        // Check valid content-type header.
        let request = req("application/x-amz-json-1.0");
        assert!(check_aws_json_10_content_type(&request).is_ok());

        // Check invalid content-type header.
        let invalid = vec![
            req("application/x-amz-json-1."),
            req("application/-amz-json-1.0"),
            req("application/xml"),
            req("application/json"),
            req("applicatio/x-amz-json-1.0"),
            req("text/xml"),
            req("application/x-amz-json-1.1"),
            RequestParts::new(Request::builder().body("").unwrap()),
        ];
        for request in &invalid {
            validate_rejection_type!(
                check_aws_json_10_content_type(request),
                RequestRejection::MissingAwsJson10ContentType
            );
    }

        // Check request with not parsable content-type header.
        validate_rejection_type!(check_aws_json_10_content_type(&req("123")), RequestRejection::MimeParse);
    #[test]
    fn check_missing_content_type() {
        let request = RequestParts::new(Request::builder().body("").unwrap());
        let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
        assert!(matches!(
            result.unwrap_err(),
            MissingContentTypeReason::NoContentTypeHeader
        ));
    }

    #[test]
    fn validate_aws_json_11_content_type() {
        // Check valid content-type header.
        let request = req("application/x-amz-json-1.1");
        assert!(check_aws_json_11_content_type(&request).is_ok());

        // Check invalid content-type header.
        let invalid = vec![
            req("application/x-amz-json-1."),
            req("application/-amz-json-1.1"),
            req("application/xml"),
            req("application/json"),
            req("applicatio/x-amz-json-1.1"),
            req("text/xml"),
            req("application/x-amz-json-1.0"),
            RequestParts::new(Request::builder().body("").unwrap()),
        ];
        for request in &invalid {
            validate_rejection_type!(
                check_aws_json_11_content_type(request),
                RequestRejection::MissingAwsJson11ContentType
            );
    fn check_not_parsable_content_type() {
        let request = req("123");
        let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
        assert!(matches!(
            result.unwrap_err(),
            MissingContentTypeReason::MimeParseError(_)
        ));
    }

        // Check request with not parsable content-type header.
        validate_rejection_type!(check_aws_json_11_content_type(&req("123")), RequestRejection::MimeParse);
    #[test]
    fn check_non_ascii_visible_characters_content_type() {
        let request = req("application/💩");
        let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON);
        assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_)));
    }
}
+20 −6
Original line number Diff line number Diff line
@@ -133,12 +133,8 @@ pub enum RequestRejection {
    /// `hyper::body::to_bytes`.
    HttpBody(crate::Error),

    // These are used when checking the `Content-Type` header.
    MissingRestJson1ContentType,
    MissingAwsJson10ContentType,
    MissingAwsJson11ContentType,
    MissingRestXmlContentType,
    MimeParse,
    /// Used when checking the `Content-Type` header.
    MissingContentType(MissingContentTypeReason),

    /// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to
    /// the modeled input it should represent.
@@ -192,6 +188,18 @@ pub enum RequestRejection {
    EnumVariantNotFound(Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Debug, Display)]
pub enum MissingContentTypeReason {
    HeadersTakenByAnotherExtractor,
    NoContentTypeHeader,
    ToStrError(http::header::ToStrError),
    MimeParseError(mime::FromStrError),
    UnexpectedMimeType {
        expected_mime: &'static mime::Mime,
        found_mime: mime::Mime,
    },
}

impl std::error::Error for RequestRejection {}

// Consider a conversion between `T` and `U` followed by a bubbling up of the conversion error
@@ -213,6 +221,12 @@ impl From<std::convert::Infallible> for RequestRejection {
    }
}

impl From<MissingContentTypeReason> for RequestRejection {
    fn from(e: MissingContentTypeReason) -> Self {
        Self::MissingContentType(e)
    }
}

// These converters are solely to make code-generation simpler. They convert from a specific error
// type (from a runtime/third-party crate or the standard library) into a variant of the
// [`crate::rejection::RequestRejection`] enum holding the type-erased boxed [`crate::Error`]