Unverified Commit d48c2347 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

`rust-server-codegen`: add `ResponseExtensions` to non-fallible operations (#1072)

`rust-server-codegen`: add `ResponseExtensions` to non-fallible operations

This commit also renames `RequestExtensions` to `ResponseExtensions`,
since we add this extension type only to HTTP responses. It also makes
its members private.

Closes #1063.
parent dac48166
Loading
Loading
Loading
Loading
+13 −22
Original line number Diff line number Diff line
@@ -311,7 +311,7 @@ class ServerProtocolTestGenerator(
        checkHeaders(this, "&http_response.headers()", testCase.headers)
        checkForbidHeaders(this, "&http_response.headers()", testCase.forbidHeaders)
        checkRequiredHeaders(this, "&http_response.headers()", testCase.requireHeaders)
        checkHttpExtensions(this)
        checkHttpResponseExtensions(this)
        if (!testCase.body.isEmpty) {
            rustTemplate(
                """
@@ -341,15 +341,20 @@ class ServerProtocolTestGenerator(
        }
    }

    private fun checkHttpExtensions(rustWriter: RustWriter) {
    private fun checkHttpResponseExtensions(rustWriter: RustWriter) {
        rustWriter.rustTemplate(
            """
            let request_extensions = http_response.extensions().get::<aws_smithy_http_server::RequestExtensions>().expect("extension `RequestExtensions` not found");
            #{AssertEq}(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
            #{AssertEq}(request_extensions.operation_name, ${operationSymbol.name.dq()});
            let response_extensions = http_response.extensions()
                .get::<#{SmithyHttpServer}::ResponseExtensions>()
                .expect("extension `ResponseExtensions` not found");
            """.trimIndent(),
            *codegenScope
        )
        rustWriter.writeWithNoFormatting(
            """
            assert_eq!(response_extensions.operation(), format!("{}#{}", "${operationShape.id.namespace}", "${operationSymbol.name}"));
            """.trimIndent()
        )
    }

    private fun checkRequiredHeaders(rustWriter: RustWriter, actualExpression: String, requireHeaders: List<String>) {
@@ -486,21 +491,9 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Request),
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersAreNotPresent", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
            FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Response),
            FailingTest(RestJson, "RestJsonNullAndEmptyHeaders", Action.Response),
            FailingTest(RestJson, "MediaTypeHeaderOutputBase64", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithEnumHeaders", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderOutputs", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithNumericHeaders", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithBooleanHeaders", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithTimestampHeaders", Action.Response),
            FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatHeaderOutputs", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsInfinityFloatHeaderOutputs", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),
            FailingTest(RestJson, "RestJsonTimestampFormatHeaders", Action.Response),

            FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithBlob", Action.Request),
            FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
@@ -513,7 +506,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "DocumentOutputArray", Action.Response),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response),
            FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
            FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestJson, "RestJsonInvalidGreetingError", Action.Response),
@@ -534,7 +526,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
            FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
            FailingTest(RestJson, "StringPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonIgnoreQueryParamsInResponse", Action.Response),
            FailingTest(RestJson, "RestJsonJsonBlobs", Action.Response),
            FailingTest(RestJson, "RestJsonJsonEnums", Action.Response),
            FailingTest(RestJson, "RestJsonLists", Action.Response),
@@ -563,9 +554,8 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
            FailingTest(RestJson, "RestJsonRecursiveShapes", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
            FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Response),
            FailingTest(RestJson, "RestJsonServersDontSerializeNullStructureValues", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Response),
            FailingTest(RestJson, "RestJsonSimpleScalarProperties", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsInfinityFloatInputs", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request),
@@ -584,6 +574,7 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonTestPayloadBlob", Action.Request),
            FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),
            FailingTest(RestJson, "RestJsonTestPayloadStructure", Action.Request),

            FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", Action.Response),
            FailingTest("com.amazonaws.s3#AmazonS3", "S3DefaultAddressing", Action.Request),
            FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAddressing", Action.Request),
+7 −5
Original line number Diff line number Diff line
@@ -250,10 +250,12 @@ private class ServerHttpProtocolImplGenerator(
                intoResponseStreaming
            } else {
                """
                match #{serialize_response}(&self.0) {
                let mut response = match #{serialize_response}(&self.0) {
                    Ok(response) => response,
                    Err(e) => e.into_response()
                }
                };
                $httpExtensions
                response
                """.trimIndent()
            }
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
@@ -317,13 +319,13 @@ private class ServerHttpProtocolImplGenerator(
    }

    /*
     * Set `http::Extensions` for the current request. They can be used later for things like metrics, logging, etc..
     * Set `http::Extensions` for the current request. They can be used later for things like metrics, logging...
     */
    private fun setHttpExtensions(operationShape: OperationShape): String {
        val namespace = operationShape.id.getNamespace()
        val namespace = operationShape.id.namespace
        val operationName = symbolProvider.toSymbol(operationShape).name
        return """
            response.extensions_mut().insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()}));
            response.extensions_mut().insert(#{SmithyHttpServer}::ResponseExtensions::new("$namespace", "$operationName"));
        """.trimIndent()
    }

+8 −8
Original line number Diff line number Diff line
@@ -39,17 +39,17 @@ use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use std::ops::Deref;

/// Extension type used to store Smithy request information.
#[derive(Debug, Clone, Default, Copy)]
pub struct RequestExtensions {
/// Extension type used to store information in HTTP responses.
#[derive(Debug, Clone)]
pub struct ResponseExtensions {
    /// Smithy model namespace.
    pub namespace: &'static str,
    namespace: &'static str,
    /// Smithy operation name.
    pub operation_name: &'static str,
    operation_name: &'static str,
}

impl RequestExtensions {
    /// Generates a new `RequestExtensions`.
impl ResponseExtensions {
    /// Creates a new `ResponseExtensions`.
    pub fn new(namespace: &'static str, operation_name: &'static str) -> Self {
        Self {
            namespace,
@@ -57,7 +57,7 @@ impl RequestExtensions {
        }
    }

    /// Returns the current operation formatted as <namespace>#<operation_name>.
    /// Returns the current operation formatted as `<namespace>#<operation_name>`.
    pub fn operation(&self) -> String {
        format!("{}#{}", self.namespace, self.operation_name)
    }
+1 −1
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody};
#[doc(inline)]
pub use self::error::Error;
#[doc(inline)]
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions};
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, ResponseExtensions};
#[doc(inline)]
pub use self::routing::Router;
#[doc(inline)]