diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 9df9b47c1def82be01f37c28942f0c9a0d2c1978..6739d955246749cc12874420feade839b8dc2ccc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -164,7 +164,7 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator: override fun structureShape(shape: StructureShape) { logger.info("[rust-server-codegen] Generating a structure $shape") rustCrate.useShapeWriter(shape) { writer -> - StructureGenerator(model, symbolProvider, writer, shape).render() + StructureGenerator(model, symbolProvider, writer, shape).renderServer() val builderGenerator = BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) builderGenerator.render(writer) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt similarity index 100% rename from codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeTypes.kt rename to codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt index d8c9057bfbb2569815b8ffeeb66b4060aa83451c..953717f42e60f463b4eae5512522c375c297bcb9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerCombinedErrorGenerator.kt @@ -68,12 +68,18 @@ class ServerCombinedErrorGenerator( rust("matches!(&self, ${symbol.name}::${errorSymbol.name}(_))") } } + writer.rust("/// Returns the error name string by matching the correct variant.") + writer.rustBlock("pub fn name(&self) -> &'static str") { + delegateToVariants { + rust("_inner.name()") + } + } } writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) { rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { delegateToVariants { - rust("Some(_inner)"); + rust("Some(_inner)") } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index b60faa41569d454511687f0d8f8eb45698c8a65a..da1a7153240acaf263591e4d7cdea22d4f2ad33e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -66,6 +66,13 @@ class ServerOperationHandlerGenerator( } else { "impl #{ServerOperationHandler}::Handler for Fun" } + val storeErrorInExtensions = """{ + let error = aws_smithy_http_server::ExtensionRejection::new(r.to_string()); + let mut response = r.into_response(); + response.extensions_mut().insert(error); + return response.map($serverCrate::boxed); + } + """.trimIndent() writer.rustBlockTemplate( """ ##[#{AsyncTrait}::async_trait] @@ -78,7 +85,7 @@ class ServerOperationHandlerGenerator( val callImpl = if (state) { """let state = match $serverCrate::Extension::::from_request(&mut req).await { Ok(v) => v, - Err(r) => return r.into_response().map($serverCrate::boxed) + Err(r) => $storeErrorInExtensions }; let input_inner = input_wrapper.into(); let output_inner = self(input_inner, state).await;""" @@ -95,7 +102,7 @@ class ServerOperationHandlerGenerator( use #{AxumCore}::response::IntoResponse; let input_wrapper = match $inputWrapperName::from_request(&mut req).await { Ok(v) => v, - Err(r) => return r.into_response().map(#{SmithyHttpServer}::boxed) + Err(r) => $storeErrorInExtensions }; $callImpl let output_wrapper: $outputWrapperName = output_inner.into(); diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneration.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt similarity index 100% rename from codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGeneration.kt rename to codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index b5f5d81f2a466cd4375fc3e53ef255d36587a7de..e1c79cad82e3a1266ef42c83680ae8613011dca2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -223,6 +223,7 @@ class ServerProtocolTestGenerator( checkBody(this, body, httpRequestTestCase) } } + checkHttpExtensions(this) // Explicitly warn if the test case defined parameters that we aren't doing anything with with(httpRequestTestCase) { @@ -322,6 +323,18 @@ class ServerProtocolTestGenerator( } } + private fun checkHttpExtensions(rustWriter: RustWriter) { + rustWriter.rust( + """ + let extensions = http_request.extensions().expect("unable to extract http request extensions"); + let namespace = extensions.get::().expect("extension ExtensionNamespace not found"); + assert_eq!(**namespace, ${operationShape.id.getNamespace().dq()}); + let operation_name = extensions.get::().expect("extension ExtensionOperationName not found"); + assert_eq!(**operation_name, ${operationSymbol.name.dq()}); + """.trimIndent() + ) + } + private fun checkHeaders(rustWriter: RustWriter, headers: Map) { if (headers.isEmpty()) { return diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index d68e92e6afef86c84e2f437e30473677d62d29e0..1fa28c7a412798558edae0818d7404fdd13c828b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -109,9 +109,8 @@ private class ServerHttpProtocolImplGenerator( override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) { val inputSymbol = symbolProvider.toSymbol(operationShape.inputShape(model)) val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) - val operationName = symbolProvider.toSymbol(operationShape).name - operationWriter.renderTraits(operationName, inputSymbol, outputSymbol, operationShape) + operationWriter.renderTraits(inputSymbol, outputSymbol, operationShape) } /* @@ -125,26 +124,29 @@ private class ServerHttpProtocolImplGenerator( * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. */ private fun RustWriter.renderTraits( - operationName: String?, inputSymbol: Symbol, outputSymbol: Symbol, operationShape: OperationShape ) { + val operationName = symbolProvider.toSymbol(operationShape).name // Implement Axum `FromRequest` trait for input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" + val httpExtensions = setHttpExtensions(operationShape) val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) { // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait. // It will first offer the streaming input to the parser and potentially read the body into memory // if an error occurred or if the streaming parser indicates that it needs the full data to proceed. """ - async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts) -> Result { + async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + $httpExtensions todo!("Streaming support for input shapes is not yet supported in `smithy-rs`") } """.trimIndent() } else { """ async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + $httpExtensions Ok($inputName(#{parse_request}(req).await?)) } """.trimIndent() @@ -185,13 +187,24 @@ private class ServerHttpProtocolImplGenerator( Self::Output(o) => { match #{serialize_response}(&o) { Ok(response) => response, - Err(e) => #{http}::Response::builder().body(#{SmithyHttpServer}::body::to_boxed(e.to_string())).expect("unable to build response from output") + Err(e) => { + let mut response = #{http}::Response::builder().body(#{SmithyHttpServer}::body::to_boxed(e.to_string())).expect("unable to build response from output"); + response.extensions_mut().insert(#{SmithyHttpServer}::ExtensionRejection::new(e.to_string())); + response + } } }, Self::Error(err) => { match #{serialize_error}(&err) { - Ok(response) => response, - Err(e) => #{http}::Response::builder().body(#{SmithyHttpServer}::body::to_boxed(e.to_string())).expect("unable to build response from error") + Ok(mut response) => { + response.extensions_mut().insert(aws_smithy_http_server::ExtensionModeledError::new(err.name())); + response + }, + Err(e) => { + let mut response = #{http}::Response::builder().body(#{SmithyHttpServer}::body::to_boxed(e.to_string())).expect("unable to build response from error"); + response.extensions_mut().insert(#{SmithyHttpServer}::ExtensionRejection::new(e.to_string())); + response + } } } } @@ -289,6 +302,19 @@ private class ServerHttpProtocolImplGenerator( ) } + /* + * Set `http::Extensions` for the current request. They can be used later for things like metrics, logging, etc.. + */ + private fun setHttpExtensions(operationShape: OperationShape): String { + val namespace = operationShape.id.getNamespace() + val operationName = symbolProvider.toSymbol(operationShape).name + return """ + let extensions = req.extensions_mut().ok_or(#{SmithyHttpServer}::rejection::ExtensionsAlreadyExtracted)?; + extensions.insert(#{SmithyHttpServer}::ExtensionNamespace::new(${namespace.dq()})); + extensions.insert(#{SmithyHttpServer}::ExtensionOperationName::new(${operationName.dq()})); + """.trimIndent() + } + private fun serverParseRequest(operationShape: OperationShape): RuntimeType { val fnName = "parse_${operationShape.id.name.toSnakeCase()}_request" val inputShape = operationShape.inputShape(model) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt index 6c7b0fc6b15c1fdcf2d6ebba4b7f98a601940bbf..e1d75e2348eb932314b74822617fd6db1d969a8f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt @@ -72,6 +72,13 @@ class StructureGenerator( } } + fun renderServer() { + renderStructure() + errorTrait?.also { errorTrait -> + ErrorGenerator(symbolProvider, writer, shape, errorTrait).renderServer() + } + } + companion object { /** Returns whether a structure shape requires a fallible builder to be generated. */ fun fallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/ErrorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/ErrorGenerator.kt index f0dc3464b6072b2be67ce87fb2b860b3e4139d55..39f6529f6edcb41611c908acbfaa3f2ed0a4b5c9 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/ErrorGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/error/ErrorGenerator.kt @@ -64,11 +64,22 @@ class ErrorGenerator( private val shape: StructureShape, private val error: ErrorTrait ) { + /* + * Renders an error specific for the client implementation. + */ fun render() { - renderError() + renderError(false) } - private fun renderError() { + /* + * Renders an error specific for the server implementation, where the [name] method is added to allow + * to record encoutered error types inside `http::Extensions`. + */ + fun renderServer() { + renderError(true) + } + + private fun renderError(isServer: Boolean) { val symbol = symbolProvider.toSymbol(shape) val messageShape = shape.errorMessageMember() val message = messageShape?.let { "self.${symbolProvider.toMemberName(it)}.as_deref()" } ?: "None" @@ -87,6 +98,17 @@ class ErrorGenerator( pub fn message(&self) -> Option<&str> { $message } """ ) + if (isServer) { + rust( + """ + ##[doc(hidden)] + /// Returns the error name. + pub fn name(&self) -> &'static str { + ${shape.id.name.dq()} + } + """ + ) + } } writer.rustBlock("impl #T for ${symbol.name}", stdfmt.member("Display")) { diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index f00f85b6222c9ac1e08b9265244a5f9b0df4da2f..e43edd0d7d2bfd83546fd71f99a2595b3274e316 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -34,11 +34,47 @@ //! Extension extraction to share state across handlers. -use super::rejection::{ExtensionRejection, ExtensionsAlreadyExtracted, MissingExtension}; +use super::rejection::{ExtensionHandlingRejection, ExtensionsAlreadyExtracted, MissingExtension}; use async_trait::async_trait; use axum_core::extract::{FromRequest, RequestParts}; use std::ops::Deref; +/// Extension type used to store the Smithy model namespace. +#[derive(Debug, Clone)] +pub struct ExtensionNamespace(&'static str); +impl_extension_new_and_deref!(ExtensionNamespace); + +/// Extension type used to store the Smithy operation name. +#[derive(Debug, Clone)] +pub struct ExtensionOperationName(&'static str); +impl_extension_new_and_deref!(ExtensionOperationName); + +/// Extension type used to store the type of user defined error returned by an operation. +/// These are modeled errors, defined in the Smithy model. +#[derive(Debug, Clone)] +pub struct ExtensionModeledError(&'static str); +impl_extension_new_and_deref!(ExtensionModeledError); + +/// Extension type used to store the type of framework error caught during execution. +/// These are unmodeled error, or rejection, defined in the runtime crates. +#[derive(Debug, Clone)] +pub struct ExtensionRejection(String); + +impl ExtensionRejection { + /// Returns a new `ExtensionRejection`. + pub fn new(value: String) -> ExtensionRejection { + ExtensionRejection(value) + } +} + +impl Deref for ExtensionRejection { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + /// Extractor that gets a value from [request extensions]. /// /// This is commonly used to share state across handlers. @@ -56,7 +92,7 @@ where T: Clone + Send + Sync + 'static, B: Send, { - type Rejection = ExtensionRejection; + type Rejection = ExtensionHandlingRejection; async fn from_request(req: &mut RequestParts) -> Result { let value = req diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 096c5a6f2d74e093a48f89c9c28186c1ef4495af..f68cbb36e7d7d3329c0d1b4fd41543819ace01ab 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -24,7 +24,9 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody}; #[doc(inline)] pub use self::error::Error; #[doc(inline)] -pub use self::extension::Extension; +pub use self::extension::{ + Extension, ExtensionModeledError, ExtensionNamespace, ExtensionOperationName, ExtensionRejection, +}; #[doc(inline)] pub use self::routing::Router; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server/src/macros.rs b/rust-runtime/aws-smithy-http-server/src/macros.rs index 89360bc8b84d320831c4018845fcbe6feea170b9..138e4cc8527dc7fb139ed783dd2e977530b89533 100644 --- a/rust-runtime/aws-smithy-http-server/src/macros.rs +++ b/rust-runtime/aws-smithy-http-server/src/macros.rs @@ -231,3 +231,23 @@ macro_rules! opaque_future { } pub use opaque_future; + +/// Implements `new` for all `Extension` holding a `&'static, str`. +macro_rules! impl_extension_new_and_deref { + ($name:ident) => { + impl $name { + #[doc = concat!("Returns a new `", stringify!($name), "`.")] + pub fn new(value: &'static str) -> $name { + $name(value) + } + } + + impl Deref for $name { + type Target = &'static str; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + }; +} diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index cede9b2c16b5fbaebc52d894d948d5ce682dd9d7..5e7a9a62046bbb31f3b4c7efef91b1772823e5fd 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -101,7 +101,7 @@ composite_rejection! { /// /// Contains one variant for each way the [`Extension`](super::Extension) extractor /// can fail. - pub enum ExtensionRejection { + pub enum ExtensionHandlingRejection { MissingExtension, ExtensionsAlreadyExtracted, } @@ -124,6 +124,7 @@ composite_rejection! { ContentTypeRejection, BodyAlreadyExtracted, HeadersAlreadyExtracted, + ExtensionsAlreadyExtracted, } }