From e6600aa805cfcd033162bb69546c75e935497a56 Mon Sep 17 00:00:00 2001 From: david-perez Date: Mon, 15 Nov 2021 13:37:48 +0100 Subject: [PATCH] Add operation input/output wrapper conversion functions (#863) The functions allow to wrap the model types into the wrappers and unwrap the wrappers into the model types. The wrapper types have been made private. --- .../protocol/ServerProtocolTestGenerator.kt | 2 +- .../protocols/ServerHttpProtocolGenerator.kt | 138 +++++++++++++----- 2 files changed, 99 insertions(+), 41 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index e70897ca5..2a4425386 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -248,7 +248,7 @@ class ServerProtocolTestGenerator( } rustTemplate( """ - let output = super::$operationName(output); + let output = super::$operationName::Output(output); use #{Axum}::response::IntoResponse; let http_response = output.into_response(); """, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 9e68c0e81..de0d36a2d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -86,27 +86,22 @@ private class ServerHttpProtocolImplGenerator( private val logger = Logger.getLogger(javaClass.name) private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model - private val errorType = RuntimeType("error", null, "crate") private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver private val operationDeserModule = RustModule.private("operation_deser") private val operationSerModule = RustModule.private("operation_ser") - private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() private val codegenScope = arrayOf( - "JsonObjectWriter" to smithyJson.member("serialize::JsonObjectWriter"), - "http" to RuntimeType.http, - "Bytes" to RuntimeType.Bytes, - "LazyStatic" to CargoDependency.LazyStatic.asType(), - "Regex" to CargoDependency.Regex.asType(), - "PercentEncoding" to CargoDependency.PercentEncoding.asType(), "Axum" to CargoDependency.Axum.asType(), + "DateTime" to RuntimeType.DateTime(runtimeConfig), "HttpBody" to CargoDependency.HttpBody.asType(), "Hyper" to CargoDependency.Hyper.asType(), + "LazyStatic" to CargoDependency.LazyStatic.asType(), + "PercentEncoding" to CargoDependency.PercentEncoding.asType(), + "Regex" to CargoDependency.Regex.asType(), "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(), "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "DateTime" to RuntimeType.DateTime(runtimeConfig) + "http" to RuntimeType.http, ) override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) { @@ -140,12 +135,11 @@ private class ServerHttpProtocolImplGenerator( outputSymbol: Symbol, operationShape: OperationShape ) { - val errorSymbol = operationShape.errorSymbol(symbolProvider) // Implement Axum `FromRequest` trait for non streaming input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" rustTemplate( """ - pub struct $inputName(#{I}); + struct $inputName(#{I}); ##[#{Axum}::async_trait] impl #{Axum}::extract::FromRequest for $inputName where @@ -159,55 +153,119 @@ private class ServerHttpProtocolImplGenerator( #{SmithyHttpServer}::protocols::check_json_content_type(req)?; Ok($inputName(#{parse_request}(req).await?)) } - }""", + }""".trimIndent(), *codegenScope, "I" to inputSymbol, "parse_request" to serverParseRequest(operationShape) ) + // Implement Axum `IntoResponse` for non streaming output types. val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" - rustTemplate( - """ - pub struct $outputName(#{O}); - ##[#{Axum}::async_trait] - impl #{Axum}::response::IntoResponse for $outputName { - type Body = #{SmithyHttpServer}::Body; - type BodyError = ::Error; - - fn into_response(self) -> #{http}::Response { - match #{serialize_response}(&self.0) { - Ok(response) => response, - Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error") - } - } - }""", - *codegenScope, - "O" to outputSymbol, - "serialize_response" to serverSerializeResponse(operationShape) - ) + val errorSymbol = operationShape.errorSymbol(symbolProvider) + + val handleSerializeOutput = """ + Ok(response) => response, + Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output") + """.trimIndent() if (operationShape.errors.isNotEmpty()) { - // Implement Axum `IntoResponse` for non streaming error types. - val errorName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_ERROR_WRAPPER_SUFFIX}" + // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control + // that can in turn be converted into a response. rustTemplate( """ - pub struct $errorName(#{E}); + enum $outputName { + Output(#{O}), + Error(#{E}) + } ##[#{Axum}::async_trait] - impl #{Axum}::response::IntoResponse for $errorName { + impl #{Axum}::response::IntoResponse for $outputName { type Body = #{SmithyHttpServer}::Body; type BodyError = ::Error; fn into_response(self) -> #{http}::Response { - match #{serialize_error}(&self.0) { - Ok(response) => response, - Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error") + match self { + Self::Output(o) => { + match #{serialize_response}(&o) { + $handleSerializeOutput + } + }, + Self::Error(err) => { + match #{serialize_error}(&err) { + Ok(response) => response, + Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error") + } + } } } - }""", + }""".trimIndent(), *codegenScope, + "O" to outputSymbol, "E" to errorSymbol, + "serialize_response" to serverSerializeResponse(operationShape), "serialize_error" to serverSerializeError(operationShape) ) + } else { + // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type + // we control that can in turn be converted into a response. + rustTemplate( + """ + struct $outputName(#{O}); + ##[#{Axum}::async_trait] + impl #{Axum}::response::IntoResponse for $outputName { + type Body = #{SmithyHttpServer}::Body; + type BodyError = ::Error; + + fn into_response(self) -> #{http}::Response { + match #{serialize_response}(&self.0) { + $handleSerializeOutput + } + } + }""".trimIndent(), + *codegenScope, + "O" to outputSymbol, + "serialize_response" to serverSerializeResponse(operationShape) + ) + } + + // Implement conversion function to "wrap" from the model operation output types. + if (operationShape.errors.isNotEmpty()) { + rustTemplate( + """ + impl From> for $outputName { + fn from(res: Result<#{O}, #{E}>) -> Self { + match res { + Ok(v) => Self::Output(v), + Err(e) => Self::Error(e), + } + } + } + """.trimIndent(), + "O" to outputSymbol, + "E" to errorSymbol + ) + } else { + rustTemplate( + """ + impl From<#{O}> for $outputName { + fn from(o: #{O}) -> Self { + Self(o) + } + } + """.trimIndent(), + "O" to outputSymbol + ) } + + // Implement conversion function to "unwrap" into the model operation input types. + rustTemplate( + """ + impl From<$inputName> for #{I} { + fn from(i: $inputName) -> Self { + i.0 + } + } + """.trimIndent(), + "I" to inputSymbol + ) } /* -- GitLab