Unverified Commit c78c67e3 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Comply with content types for the accept header (#1424)

* Verify accept header for payloads with `@mediaType`

Comply with content-types as described in the documentation [0]

[0] https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html#content-type



Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.com>
parent b7506ec2
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -661,9 +661,7 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", TestType.Request),
            FailingTest(RestJson, "RestJsonHttpResponseCodeDefaultsToModeledCode", TestType.Response),

            FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonAccept", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedAccept", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledAccept", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyByteMalformedValueRejected_case2", TestType.MalformedRequest),
+24 −2
Original line number Diff line number Diff line
@@ -159,6 +159,26 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
        val operationName = symbolProvider.toSymbol(operationShape).name
        val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

        val verifyResponseContentType = writable {
            httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
                rustTemplate(
                    """
                    if let Some(headers) = req.headers() {
                        if let Some(accept) = headers.get(#{http}::header::ACCEPT) {
                            if accept != "$contentType" {
                                return Err(Self::Rejection {
                                    protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
                                    kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable,
                                })
                            }
                        }
                    }
                    """,
                    *codegenScope,
                )
            }
        }

        // Implement `FromRequest` trait for input types.
        rustTemplate(
            """
@@ -173,6 +193,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
            {
                type Rejection = #{RuntimeError};
                async fn from_request(req: &mut #{SmithyHttpServer}::request::RequestParts<B>) -> Result<Self, Self::Rejection> {
                    #{verify_response_content_type:W}
                    #{parse_request}(req)
                        .await
                        .map($inputName)
@@ -187,7 +208,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
            """.trimIndent(),
            *codegenScope,
            "I" to inputSymbol,
            "parse_request" to serverParseRequest(operationShape)
            "parse_request" to serverParseRequest(operationShape),
            "verify_response_content_type" to verifyResponseContentType,
        )

        // Implement `IntoResponse` for output types.
@@ -227,7 +249,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
                        }
                    }
                }
                """.trimIndent()
                """

            rustTemplate(
                """
+22 −3
Original line number Diff line number Diff line
@@ -9,7 +9,10 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
@@ -23,6 +26,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.outputShape

class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
    override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext)
@@ -56,15 +61,29 @@ class RestJsonFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> {
 * `application/json` if not overridden.
 */
class RestJsonHttpBindingResolver(
    model: Model,
    private val model: Model,
    contentTypes: ProtocolContentTypes,
) : HttpTraitHttpBindingResolver(model, contentTypes) {
    /**
     * In the RestJson1 protocol, HTTP responses have a default `Content-Type: application/json` header if it is not
     * overridden by a specific mechanism e.g. an output shape member is targeted with `httpPayload` or `mediaType` traits.
     */
    override fun responseContentType(operationShape: OperationShape): String =
        super.responseContentType(operationShape) ?: "application/json"
    override fun responseContentType(operationShape: OperationShape): String? {
        val members = operationShape
            .outputShape(model)
            .members()
        // TODO(https://github.com/awslabs/smithy/issues/1259)
        //  Temporary fix for https://github.com/awslabs/smithy/blob/df456a514f72f4e35f0fb07c7e26006ff03b2071/smithy-model/src/main/java/software/amazon/smithy/model/knowledge/HttpBindingIndex.java#L352
        for (member in members) {
            if (member.hasTrait<HttpPayloadTrait>()) {
                val target = model.expectShape(member.target)
                if (!target.hasTrait<StreamingTrait>() && !target.hasTrait<MediaTypeTrait>() && target.isBlobShape) {
                    return null
                }
            }
        }
        return super.responseContentType(operationShape) ?: "application/json"
    }
}

class RestJson(private val codegenContext: CodegenContext) : Protocol {
+4 −2
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ pub enum RuntimeErrorKind {
    /// [`crate::extension::Extension`] from the request.
    InternalFailure(crate::Error),
    // UnsupportedMediaType,
    // NotAcceptable,
    NotAcceptable,
}

/// String representation of the runtime error type.
@@ -47,7 +47,8 @@ impl RuntimeErrorKind {
        match self {
            RuntimeErrorKind::Serialization(_) => "SerializationException",
            RuntimeErrorKind::InternalFailure(_) => "InternalFailureException",
            RuntimeErrorKind::UnknownOperation => "UnknownOperation",
            RuntimeErrorKind::UnknownOperation => "UnknownOperationException",
            RuntimeErrorKind::NotAcceptable => "NotAcceptableException",
        }
    }
}
@@ -64,6 +65,7 @@ impl IntoResponse for RuntimeError {
            RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,
            RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR,
            RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND,
            RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE,
        };

        let body = crate::body::to_boxed(match self.protocol {