Unverified Commit fa54d8b6 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Check `Content-Type` header in all server protocols (#2531)

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [ ] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 791e8d41
Loading
Loading
Loading
Loading
+24 −21
Original line number Diff line number Diff line
@@ -57,7 +57,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtoc
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
@@ -234,6 +233,11 @@ class ServerHttpBoundProtocolTraitImplGenerator(
                rustTemplate(init, *codegenScope)
            }
        }
        // This checks for the expected `Content-Type` header if the `@httpPayload` trait is present, as dictated by
        // the core Smithy library, which _does not_ require deserializing the payload.
        // If no members have `@httpPayload`, the expected `Content-Type` header as dictated _by the protocol_ is
        // checked later on for non-streaming operations, in `serverRenderShapeParser`: that check _does_ require at
        // least buffering the entire payload, since the check must only be performed if the payload is empty.
        val verifyRequestContentTypeHeader = writable {
            operationShape
                .inputShape(model)
@@ -242,11 +246,15 @@ class ServerHttpBoundProtocolTraitImplGenerator(
                ?.let { payload ->
                    val target = model.expectShape(payload.target)
                    if (!target.isBlobShape || target.hasTrait<MediaTypeTrait>()) {
                        val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
                            ?.let { "Some(${it.dq()})" } ?: "None"
                        // `null` is only returned by Smithy when there are no members, but we know there's at least
                        // the one with `@httpPayload`, so `!!` is safe here.
                        val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!!
                        rustTemplate(
                            """
                            #{SmithyHttpServer}::protocol::content_type_header_classifier(request.headers(), $expectedRequestContentType)?;
                            #{SmithyHttpServer}::protocol::content_type_header_classifier(
                                request.headers(),
                                Some("$expectedRequestContentType"),
                            )?;
                            """,
                            *codegenScope,
                        )
@@ -689,31 +697,26 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        rust("let (parts, body) = request.into_parts();")
        val parser = structuredDataParser.serverInputParser(operationShape)
        val noInputs = model.expectShape(operationShape.inputShape).expectTrait<SyntheticInputTrait>().originalId == null

        if (parser != null) {
            // `null` is only returned by Smithy when there are no members, but we know there's at least one, since
            // there's something to parse (i.e. `parser != null`), so `!!` is safe here.
            val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)!!
            rustTemplate("let bytes = #{Hyper}::body::to_bytes(body).await?;", *codegenScope)
            rustBlock("if !bytes.is_empty()") {
                rustTemplate(
                    """
                let bytes = #{Hyper}::body::to_bytes(body).await?;
                if !bytes.is_empty() {
                """,
                *codegenScope,
            )
            if (protocol is RestJson) {
                rustTemplate(
                    """
                    #{SmithyHttpServer}::protocol::content_type_header_classifier(&parts.headers, Some("application/json"))?;
                    """,
                    *codegenScope,
                )
            }
            rustTemplate(
                """
                    #{SmithyHttpServer}::protocol::content_type_header_classifier(
                        &parts.headers, 
                        Some("$expectedRequestContentType"),
                    )?;
                    input = #{parser}(bytes.as_ref(), input)?;
                }
                    """,
                    *codegenScope,
                    "parser" to parser,
                )
            }
        }
        for (binding in bindings) {
            val member = binding.member
            val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
+2 −0
Original line number Diff line number Diff line
@@ -115,6 +115,8 @@ pub enum RequestRejection {
    NotAcceptable,

    /// Used when checking the `Content-Type` header.
    /// This is bubbled up in the generated SDK when calling
    /// [`crate::protocol::content_type_header_classifier`] in `from_request`.
    #[error("expected `Content-Type` header not found: {0}")]
    MissingContentType(#[from] MissingContentTypeReason),

+3 −0
Original line number Diff line number Diff line
@@ -47,6 +47,9 @@ pub enum RuntimeError {
    InternalFailure(crate::Error),
    /// Request contained an `Accept` header with a MIME type, and the server cannot return a response
    /// body adhering to that MIME type.
    /// This is returned directly (i.e. without going through a [`RequestRejection`] first) in the
    /// generated SDK when calling [`crate::protocol::accept_header_classifier`] in
    /// `from_request`.
    NotAcceptable,
    /// The request does not contain the expected `Content-Type` header value.
    UnsupportedMediaType,