Unverified Commit f324240b authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Unsupported content type (#1723)



Add validation for the Content-Type header and pass (remove from the failing
list) the relevant protocol tests

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent 23128431
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -90,6 +90,13 @@ interface Protocol {
     * protocol.
     */
    fun serverRouterRuntimeConstructor(): String

    /**
     * In some protocols, such as restJson1,
     * when there is no modeled body input, content type must not be set and the body must be empty.
     * Returns a boolean indicating whether to perform this check.
     */
    fun serverContentTypeCheckNoModeledInput(): Boolean = false
}

typealias ProtocolMap<C> = Map<ShapeId, ProtocolGeneratorFactory<ProtocolGenerator, C>>
+2 −0
Original line number Diff line number Diff line
@@ -150,6 +150,8 @@ open class RestJson(val coreCodegenContext: CoreCodegenContext) : Protocol {
    ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

    override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

    override fun serverContentTypeCheckNoModeledInput() = true
}

fun restJsonFieldName(member: MemberShape): String {
+2 −0
Original line number Diff line number Diff line
@@ -112,6 +112,8 @@ open class RestXml(val coreCodegenContext: CoreCodegenContext) : Protocol {
    ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape)

    override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

    override fun serverContentTypeCheckNoModeledInput() = true
}

/**
+0 −3
Original line number Diff line number Diff line
@@ -841,9 +841,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledContentType", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithoutBodyExpectsEmptyContentType", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedListNullItem", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedMapNullValue", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonMalformedSetDuplicateItems", TestType.MalformedRequest),
+45 −7
Original line number Diff line number Diff line
@@ -21,7 +21,9 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.rust.codegen.client.rustlang.Attribute
import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
@@ -29,6 +31,7 @@ import software.amazon.smithy.rust.codegen.client.rustlang.RustType
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.rustlang.Writable
import software.amazon.smithy.rust.codegen.client.rustlang.asType
import software.amazon.smithy.rust.codegen.client.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.client.rustlang.render
import software.amazon.smithy.rust.codegen.client.rustlang.rust
import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock
@@ -58,6 +61,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.client.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.client.smithy.toOptional
import software.amazon.smithy.rust.codegen.client.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.client.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.client.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.client.util.dq
@@ -65,6 +69,7 @@ import software.amazon.smithy.rust.codegen.client.util.expectTrait
import software.amazon.smithy.rust.codegen.client.util.findStreamingMember
import software.amazon.smithy.rust.codegen.client.util.getTrait
import software.amazon.smithy.rust.codegen.client.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.client.util.hasTrait
import software.amazon.smithy.rust.codegen.client.util.inputShape
import software.amazon.smithy.rust.codegen.client.util.isStreaming
import software.amazon.smithy.rust.codegen.client.util.outputShape
@@ -168,7 +173,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        val operationName = symbolProvider.toSymbol(operationShape).name
        val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

        val verifyResponseContentType = writable {
        val verifyAcceptHeader = writable {
            httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
                rustTemplate(
                    """
@@ -183,6 +188,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                )
            }
        }
        val verifyRequestContentTypeHeader = writable {
            operationShape
                .inputShape(model)
                .members()
                .find { it.hasTrait<HttpPayloadTrait>() }
                ?.let { payload ->
                    val target = model.expectShape(payload.target)
                    if (!target.isBlobShape || target.hasTrait<MediaTypeTrait>()) {
                        val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
                            ?.let { "Some(${it.dq()})" } ?: "None"
                        rustTemplate(
                            """
                            if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() {
                                return Err(#{RuntimeError} {
                                    protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
                                    kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType,
                                })
                            }
                            """,
                            *codegenScope,
                        )
                    }
                }
        }

        // Implement `from_request` trait for input types.
        rustTemplate(
@@ -197,7 +226,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                    B::Data: Send,
                    #{RequestRejection} : From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
                {
                    #{verify_response_content_type:W}
                    #{verifyAcceptHeader:W}
                    #{verifyRequestContentTypeHeader:W}
                    #{parse_request}(req)
                        .await
                        .map($inputName)
@@ -235,7 +265,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
            "I" to inputSymbol,
            "Marker" to serverProtocol.markerStruct(),
            "parse_request" to serverParseRequest(operationShape),
            "verify_response_content_type" to verifyResponseContentType,
            "verifyAcceptHeader" to verifyAcceptHeader,
            "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader,
        )

        // Implement `into_response` for output types.
@@ -711,16 +742,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        Attribute.AllowUnusedMut.render(this)
        rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
        val parser = structuredDataParser.serverInputParser(operationShape)
        val noInputs = model.expectShape(operationShape.inputShape).expectTrait<SyntheticInputTrait>().originalId == null
        if (parser != null) {
            val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape)
            rustTemplate(
                """
                let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?;
                let bytes = #{Hyper}::body::to_bytes(body).await?;
                if !bytes.is_empty() {
                    static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> =
                        #{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap());
                    #{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?;
                    input = #{parser}(bytes.as_ref(), input)?;
                }
                """,
@@ -740,6 +768,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        serverRenderUriPathParser(this, operationShape)
        serverRenderQueryStringParser(this, operationShape)

        if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) {
            conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) {
                rustTemplate(
                    """
                    #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?;
                    """,
                    *codegenScope,
                )
            }
        }
        val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) {
            "?"
        } else ""
Loading