Unverified Commit e6600aa8 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Add operation input/output wrapper conversion functions (#863)

The functions allow to wrap the model types into the wrappers and unwrap
the wrappers into the model types.

The wrapper types have been made private.
parent ada03d4c
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -248,7 +248,7 @@ class ServerProtocolTestGenerator(
        }
        rustTemplate(
            """
            let output = super::$operationName(output);
            let output = super::$operationName::Output(output);
            use #{Axum}::response::IntoResponse;
            let http_response = output.into_response();
            """,
+98 −40
Original line number Diff line number Diff line
@@ -86,27 +86,22 @@ private class ServerHttpProtocolImplGenerator(
    private val logger = Logger.getLogger(javaClass.name)
    private val symbolProvider = codegenContext.symbolProvider
    private val model = codegenContext.model
    private val errorType = RuntimeType("error", null, "crate")
    private val runtimeConfig = codegenContext.runtimeConfig
    private val httpBindingResolver = protocol.httpBindingResolver
    private val operationDeserModule = RustModule.private("operation_deser")
    private val operationSerModule = RustModule.private("operation_ser")
    private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()

    private val codegenScope = arrayOf(
        "JsonObjectWriter" to smithyJson.member("serialize::JsonObjectWriter"),
        "http" to RuntimeType.http,
        "Bytes" to RuntimeType.Bytes,
        "LazyStatic" to CargoDependency.LazyStatic.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Axum" to CargoDependency.Axum.asType(),
        "DateTime" to RuntimeType.DateTime(runtimeConfig),
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
        "LazyStatic" to CargoDependency.LazyStatic.asType(),
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "SdkBody" to RuntimeType.sdkBody(runtimeConfig),
        "DateTime" to RuntimeType.DateTime(runtimeConfig)
        "http" to RuntimeType.http,
    )

    override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
@@ -140,12 +135,11 @@ private class ServerHttpProtocolImplGenerator(
        outputSymbol: Symbol,
        operationShape: OperationShape
    ) {
        val errorSymbol = operationShape.errorSymbol(symbolProvider)
        // Implement Axum `FromRequest` trait for non streaming input types.
        val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
        rustTemplate(
            """
            pub struct $inputName(#{I});
            struct $inputName(#{I});
            ##[#{Axum}::async_trait]
            impl<B> #{Axum}::extract::FromRequest<B> for $inputName
            where
@@ -159,55 +153,119 @@ private class ServerHttpProtocolImplGenerator(
                    #{SmithyHttpServer}::protocols::check_json_content_type(req)?;
                    Ok($inputName(#{parse_request}(req).await?))
                }
            }""",
            }""".trimIndent(),
            *codegenScope,
            "I" to inputSymbol,
            "parse_request" to serverParseRequest(operationShape)
        )

        // Implement Axum `IntoResponse` for non streaming output types.
        val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        val errorSymbol = operationShape.errorSymbol(symbolProvider)

        val handleSerializeOutput = """
            Ok(response) => response,
            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
        """.trimIndent()
        if (operationShape.errors.isNotEmpty()) {
            // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
            // that can in turn be converted into a response.
            rustTemplate(
                """
            pub struct $outputName(#{O});
                enum $outputName {
                    Output(#{O}),
                    Error(#{E})
                }
                ##[#{Axum}::async_trait]
                impl #{Axum}::response::IntoResponse for $outputName {
                    type Body = #{SmithyHttpServer}::Body;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;

                    fn into_response(self) -> #{http}::Response<Self::Body> {
                    match #{serialize_response}(&self.0) {
                        match self {
                            Self::Output(o) => {
                                match #{serialize_response}(&o) {
                                    $handleSerializeOutput
                                }
                            },
                            Self::Error(err) => {
                                match #{serialize_error}(&err) {
                                    Ok(response) => response,
                                    Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                                }
                            }
            }""",
                        }
                    }
                }""".trimIndent(),
                *codegenScope,
                "O" to outputSymbol,
            "serialize_response" to serverSerializeResponse(operationShape)
                "E" to errorSymbol,
                "serialize_response" to serverSerializeResponse(operationShape),
                "serialize_error" to serverSerializeError(operationShape)
            )
        if (operationShape.errors.isNotEmpty()) {
            // Implement Axum `IntoResponse` for non streaming error types.
            val errorName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_ERROR_WRAPPER_SUFFIX}"
        } else {
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
            // we control that can in turn be converted into a response.
            rustTemplate(
                """
                pub struct $errorName(#{E});
                struct $outputName(#{O});
                ##[#{Axum}::async_trait]
                impl #{Axum}::response::IntoResponse for $errorName {
                impl #{Axum}::response::IntoResponse for $outputName {
                    type Body = #{SmithyHttpServer}::Body;
                    type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;

                    fn into_response(self) -> #{http}::Response<Self::Body> {
                        match #{serialize_error}(&self.0) {
                            Ok(response) => response,
                            Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
                        match #{serialize_response}(&self.0) {
                            $handleSerializeOutput
                        }
                    }
                }""",
                }""".trimIndent(),
                *codegenScope,
                "E" to errorSymbol,
                "serialize_error" to serverSerializeError(operationShape)
                "O" to outputSymbol,
                "serialize_response" to serverSerializeResponse(operationShape)
            )
        }

        // Implement conversion function to "wrap" from the model operation output types.
        if (operationShape.errors.isNotEmpty()) {
            rustTemplate(
                """
                impl From<Result<#{O}, #{E}>> for $outputName {
                    fn from(res: Result<#{O}, #{E}>) -> Self {
                        match res {
                            Ok(v) => Self::Output(v),
                            Err(e) => Self::Error(e),
                        }
                    }
                }
                """.trimIndent(),
                "O" to outputSymbol,
                "E" to errorSymbol
            )
        } else {
            rustTemplate(
                """
                impl From<#{O}> for $outputName {
                    fn from(o: #{O}) -> Self {
                        Self(o)
                    }
                }
                """.trimIndent(),
                "O" to outputSymbol
            )
        }

        // Implement conversion function to "unwrap" into the model operation input types.
        rustTemplate(
            """
            impl From<$inputName> for #{I} {
                fn from(i: $inputName) -> Self {
                    i.0
                }
            }
            """.trimIndent(),
            "I" to inputSymbol
        )
    }

    /*