Unverified Commit bc4c1862 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Lazy initialize mime accept header (#2629)

## Description

Do not parse and initialize the mime, known at compile time, on every
request.

See
https://github.com/awslabs/smithy-rs/pull/2607#discussion_r1172607704




## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent 168184e5
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -178,11 +178,13 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        outputSymbol: Symbol,
        operationShape: OperationShape,
    ) {
        val operationName = symbolProvider.toSymbol(operationShape).name
        val staticContentType = "CONTENT_TYPE_${operationName.uppercase()}"
        val verifyAcceptHeader = writable {
            httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
                rustTemplate(
                    """
                    if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) {
                    if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), &$staticContentType) {
                        return Err(#{RequestRejection}::NotAcceptable);
                    }
                    """,
@@ -190,6 +192,22 @@ class ServerHttpBoundProtocolTraitImplGenerator(
                )
            }
        }
        val verifyAcceptHeaderStaticContentTypeInit = writable {
            httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
                val init = when (contentType) {
                    "application/json" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;"
                    "application/octet-stream" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;"
                    "application/x-www-form-urlencoded" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;"
                    else ->
                        """
                    static $staticContentType: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = #{OnceCell}::sync::Lazy::new(|| {
                        ${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid")
                    });
                    """
                }
                rustTemplate(init, *codegenScope)
            }
        }
        val verifyRequestContentTypeHeader = writable {
            operationShape
                .inputShape(model)
@@ -215,6 +233,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        // TODO(https://github.com/awslabs/smithy-rs/issues/2238): Remove the `Pin<Box<dyn Future>>` and replace with thin wrapper around `Collect`.
        rustTemplate(
            """
            #{verifyAcceptHeaderStaticContentTypeInit:W}
            #{PinProjectLite}::pin_project! {
                /// A [`Future`](std::future::Future) aggregating the body bytes of a [`Request`] and constructing the
                /// [`${inputSymbol.name}`](#{I}) using modelled bindings.
@@ -267,6 +286,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
            "Marker" to protocol.markerStruct(),
            "parse_request" to serverParseRequest(operationShape),
            "verifyAcceptHeader" to verifyAcceptHeader,
            "verifyAcceptHeaderStaticContentTypeInit" to verifyAcceptHeaderStaticContentTypeInit,
            "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader,
        )

+29 −12
Original line number Diff line number Diff line
@@ -66,14 +66,10 @@ pub fn content_type_header_classifier(
    Ok(())
}

pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool {
pub fn accept_header_classifier(headers: &HeaderMap, content_type: &mime::Mime) -> bool {
    if !headers.contains_key(http::header::ACCEPT) {
        return true;
    }
    // Must be of the form: type/subtype
    let content_type = content_type
        .parse::<mime::Mime>()
        .expect("BUG: MIME parsing failed, content_type is not valid");
    headers
        .get_all(http::header::ACCEPT)
        .into_iter()
@@ -195,41 +191,62 @@ mod tests {
    #[test]
    fn valid_accept_header_classifier_multiple_values() {
        let valid_request = req_accept("text/strings, application/json, invalid");
        assert!(accept_header_classifier(&valid_request, "application/json"));
        assert!(accept_header_classifier(
            &valid_request,
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn invalid_accept_header_classifier() {
        let invalid_request = req_accept("text/invalid, invalid, invalid/invalid");
        assert!(!accept_header_classifier(&invalid_request, "application/json"));
        assert!(!accept_header_classifier(
            &invalid_request,
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn valid_accept_header_classifier_star() {
        let valid_request = req_accept("application/*");
        assert!(accept_header_classifier(&valid_request, "application/json"));
        assert!(accept_header_classifier(
            &valid_request,
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn valid_accept_header_classifier_star_star() {
        let valid_request = req_accept("*/*");
        assert!(accept_header_classifier(&valid_request, "application/json"));
        assert!(accept_header_classifier(
            &valid_request,
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn valid_empty_accept_header_classifier() {
        assert!(accept_header_classifier(&HeaderMap::new(), "application/json"));
        assert!(accept_header_classifier(
            &HeaderMap::new(),
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn valid_accept_header_classifier_with_params() {
        let valid_request = req_accept("application/json; q=30, */*");
        assert!(accept_header_classifier(&valid_request, "application/json"));
        assert!(accept_header_classifier(
            &valid_request,
            &"application/json".parse().unwrap()
        ));
    }

    #[test]
    fn valid_accept_header_classifier() {
        let valid_request = req_accept("application/json");
        assert!(accept_header_classifier(&valid_request, "application/json"));
        assert!(accept_header_classifier(
            &valid_request,
            &"application/json".parse().unwrap()
        ));
    }
}