From 64ccdccccf078e6e4332890790641e952b18a149 Mon Sep 17 00:00:00 2001 From: Matteo Bigoi <1781140+crisidev@users.noreply.github.com> Date: Wed, 15 Dec 2021 13:57:04 +0000 Subject: [PATCH] Use a single struct to represent the RequestExtension (#978) --- .../protocol/ServerProtocolTestGenerator.kt | 7 +-- .../protocols/ServerHttpProtocolGenerator.kt | 61 ++++++++++++------- .../aws-smithy-http-server/rustfmt.toml | 6 -- .../aws-smithy-http-server/src/extension.rs | 30 ++++++--- .../aws-smithy-http-server/src/lib.rs | 4 +- .../aws-smithy-http-server/src/macros.rs | 21 ++++--- .../aws-smithy-http-server/src/routing/mod.rs | 39 +++++++++--- .../src/routing/request_spec.rs | 33 ++++++++-- .../src/routing/route.rs | 8 ++- 9 files changed, 143 insertions(+), 66 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index e1c79cad8..0840a8f8d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -327,10 +327,9 @@ class ServerProtocolTestGenerator( rustWriter.rust( """ let extensions = http_request.extensions().expect("unable to extract http request extensions"); - let namespace = extensions.get::().expect("extension ExtensionNamespace not found"); - assert_eq!(**namespace, ${operationShape.id.getNamespace().dq()}); - let operation_name = extensions.get::().expect("extension ExtensionOperationName not found"); - assert_eq!(**operation_name, ${operationSymbol.name.dq()}); + let request_extensions = extensions.get::().expect("extension RequestExtensions not found"); + assert_eq!(request_extensions.namespace, ${operationShape.id.getNamespace().dq()}); + assert_eq!(request_extensions.operation_name, ${operationSymbol.name.dq()}); """.trimIndent() ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 09aba7492..ee07510c5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -316,8 +316,7 @@ private class ServerHttpProtocolImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name return """ let extensions = req.extensions_mut().ok_or(#{SmithyHttpServer}::rejection::ExtensionsAlreadyExtracted)?; - extensions.insert(#{SmithyHttpServer}::ExtensionNamespace::new(${namespace.dq()})); - extensions.insert(#{SmithyHttpServer}::ExtensionOperationName::new(${operationName.dq()})); + extensions.insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()})); """.trimIndent() } @@ -640,11 +639,13 @@ private class ServerHttpProtocolImplGenerator( } else if (targetMapValue.isSetShape) { QueryParamsTargetMapValueType.SET } else { - throw ExpectationNotMetException(""" + throw ExpectationNotMetException( + """ @httpQueryParams trait applied to non-supported target $targetMapValue of type ${targetMapValue.type} """.trimIndent(), - targetMapValue.sourceLocation) + targetMapValue.sourceLocation + ) } private fun serverRenderQueryStringParser(writer: RustWriter, operationShape: OperationShape) { @@ -661,7 +662,7 @@ private class ServerHttpProtocolImplGenerator( return } - fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType { + fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType { check(this.location == HttpLocation.QUERY_PARAMS) val queryParamsTarget = model.expectShape(this.member.target) val mapTarget = queryParamsTarget.asMapShape().get() @@ -669,7 +670,8 @@ private class ServerHttpProtocolImplGenerator( } with(writer) { - rustTemplate(""" + rustTemplate( + """ let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?; let pairs = #{SerdeUrlEncoded}::from_str::>(query_string)?; """.trimIndent(), @@ -677,7 +679,8 @@ private class ServerHttpProtocolImplGenerator( ) if (queryParamsBinding != null) { - rustTemplate("let mut query_params: #{HashMap} = #{HashMap}::new();", "HashMap" to RustType.HashMap.RuntimeType, ) @@ -694,15 +697,17 @@ private class ServerHttpProtocolImplGenerator( rustBlock("for (k, v) in pairs") { queryBindingsTargettingSimple.forEach { val deserializer = generateParsePercentEncodedStrFn(it) - rustTemplate(""" + rustTemplate( + """ if !seen_${it.memberName.toSnakeCase()} && k == "${it.locationName}" { input = input.${it.member.setterName()}( #{deserializer}(v)? ); seen_${it.memberName.toSnakeCase()} = true; } - """.trimIndent(), - "deserializer" to deserializer) + """.trimIndent(), + "deserializer" to deserializer + ) } queryBindingsTargettingCollection.forEach { rustBlock("if k == ${it.locationName.dq()}") { @@ -714,9 +719,12 @@ private class ServerHttpProtocolImplGenerator( // `<_>::from()` is necessary to convert the `&str` into: // * the Rust enum in case the `string` shape has the `enum` trait; or // * `String` in case it doesn't. - rustTemplate(""" + rustTemplate( + """ let v = <_>::from(#{PercentEncoding}::percent_decode_str(v).decode_utf8()?.as_ref()); - """.trimIndent(), *codegenScope) + """.trimIndent(), + *codegenScope + ) } memberShape.isTimestampShape -> { val index = HttpBindingIndex.of(model) @@ -727,18 +735,22 @@ private class ServerHttpProtocolImplGenerator( protocol.defaultTimestampFormat, ) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) - rustTemplate(""" + rustTemplate( + """ let v = #{PercentEncoding}::percent_decode_str(v).decode_utf8()?; let v = #{DateTime}::from_str(&v, #{format})?; - """.trimIndent(), + """.trimIndent(), *codegenScope, "format" to timestampFormatType, ) } else -> { // Number or boolean. - rust(""" + rust( + """ let v = <_ as #T>::parse_smithy_primitive(v)?; - """.trimIndent(), CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")) + """.trimIndent(), + CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse") + ) } } rust("${it.memberName.toSnakeCase()}.push(v);") @@ -750,10 +762,12 @@ private class ServerHttpProtocolImplGenerator( QueryParamsTargetMapValueType.STRING -> { rust("query_params.entry(String::from(k)).or_insert_with(|| String::from(v));") } else -> { - rustTemplate(""" + rustTemplate( + """ let entry = query_params.entry(String::from(k)).or_default(); entry.push(String::from(v)); - """.trimIndent()) + """.trimIndent() + ) } } } @@ -762,9 +776,11 @@ private class ServerHttpProtocolImplGenerator( rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));") } queryBindingsTargettingCollection.forEach { - rustTemplate(""" + rustTemplate( + """ input = input.${it.member.setterName()}(Some(${it.memberName.toSnakeCase()})); - """.trimIndent()) + """.trimIndent() + ) } } } @@ -810,10 +826,11 @@ private class ServerHttpProtocolImplGenerator( // `<_>::from()` is necessary to convert the `&str` into: // * the Rust enum in case the `string` shape has the `enum` trait; or // * `String` in case it doesn't. - rustTemplate(""" + rustTemplate( + """ let value = <_>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref()); Ok(Some(value)) - """.trimIndent(), + """.trimIndent(), *codegenScope, ) } diff --git a/rust-runtime/aws-smithy-http-server/rustfmt.toml b/rust-runtime/aws-smithy-http-server/rustfmt.toml index 76072d80f..b95a9773a 100644 --- a/rust-runtime/aws-smithy-http-server/rustfmt.toml +++ b/rust-runtime/aws-smithy-http-server/rustfmt.toml @@ -1,10 +1,4 @@ edition = "2018" max_width = 120 -# The "Default" setting has a heuristic which splits lines too aggresively. -# We are willing to revisit this setting in future versions of rustfmt. -# Bugs: -# * https://github.com/rust-lang/rustfmt/issues/3119 -# * https://github.com/rust-lang/rustfmt/issues/3120 -use_small_heuristics = "Max" # Prevent carriage returns newline_style = "Unix" diff --git a/rust-runtime/aws-smithy-http-server/src/extension.rs b/rust-runtime/aws-smithy-http-server/src/extension.rs index e43edd0d7..1355c9e69 100644 --- a/rust-runtime/aws-smithy-http-server/src/extension.rs +++ b/rust-runtime/aws-smithy-http-server/src/extension.rs @@ -39,15 +39,29 @@ use async_trait::async_trait; use axum_core::extract::{FromRequest, RequestParts}; use std::ops::Deref; -/// Extension type used to store the Smithy model namespace. -#[derive(Debug, Clone)] -pub struct ExtensionNamespace(&'static str); -impl_extension_new_and_deref!(ExtensionNamespace); +/// Extension type used to store Smithy request information. +#[derive(Debug, Clone, Default, Copy)] +pub struct RequestExtensions { + /// Smithy model namespace. + pub namespace: &'static str, + /// Smithy operation name. + pub operation_name: &'static str, +} -/// Extension type used to store the Smithy operation name. -#[derive(Debug, Clone)] -pub struct ExtensionOperationName(&'static str); -impl_extension_new_and_deref!(ExtensionOperationName); +impl RequestExtensions { + /// Generates a new `RequestExtensions`. + pub fn new(namespace: &'static str, operation_name: &'static str) -> Self { + Self { + namespace, + operation_name, + } + } + + /// Returns the current operation formatted as #. + pub fn operation(&self) -> String { + format!("{}#{}", self.namespace, self.operation_name) + } +} /// Extension type used to store the type of user defined error returned by an operation. /// These are modeled errors, defined in the Smithy model. diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index f68cbb36e..95302b502 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -24,9 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody}; #[doc(inline)] pub use self::error::Error; #[doc(inline)] -pub use self::extension::{ - Extension, ExtensionModeledError, ExtensionNamespace, ExtensionOperationName, ExtensionRejection, -}; +pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions}; #[doc(inline)] pub use self::routing::Router; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server/src/macros.rs b/rust-runtime/aws-smithy-http-server/src/macros.rs index 138e4cc85..2485241fc 100644 --- a/rust-runtime/aws-smithy-http-server/src/macros.rs +++ b/rust-runtime/aws-smithy-http-server/src/macros.rs @@ -232,6 +232,19 @@ macro_rules! opaque_future { pub use opaque_future; +/// Implements `Deref` for all `Extension` holding a `&'static, str`. +macro_rules! impl_deref { + ($name:ident) => { + impl Deref for $name { + type Target = &'static str; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + }; +} + /// Implements `new` for all `Extension` holding a `&'static, str`. macro_rules! impl_extension_new_and_deref { ($name:ident) => { @@ -242,12 +255,6 @@ macro_rules! impl_extension_new_and_deref { } } - impl Deref for $name { - type Target = &'static str; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } + impl_deref!($name); }; } diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 1cb834e59..5f80442be 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -43,7 +43,9 @@ pub struct Router { impl Clone for Router { fn clone(&self) -> Self { - Self { routes: self.routes.clone() } + Self { + routes: self.routes.clone(), + } } } @@ -66,7 +68,9 @@ where /// all requests. #[doc(hidden)] pub fn new() -> Self { - Self { routes: Default::default() } + Self { + routes: Default::default(), + } } /// Add a route to the router. @@ -107,9 +111,15 @@ where NewResBody: HttpBody + Send + 'static, NewResBody::Error: Into, { - let layer = ServiceBuilder::new().layer_fn(Route::new).layer(MapResponseBodyLayer::new(boxed)).layer(layer); - let routes = - self.routes.into_iter().map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)).collect(); + let layer = ServiceBuilder::new() + .layer_fn(Route::new) + .layer(MapResponseBodyLayer::new(boxed)) + .layer(layer); + let routes = self + .routes + .into_iter() + .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)) + .collect(); Router { routes } } } @@ -142,8 +152,17 @@ where } } - let status_code = if method_not_allowed { StatusCode::METHOD_NOT_ALLOWED } else { StatusCode::NOT_FOUND }; - RouterFuture::from_response(Response::builder().status(status_code).body(crate::body::empty()).unwrap()) + let status_code = if method_not_allowed { + StatusCode::METHOD_NOT_ALLOWED + } else { + StatusCode::NOT_FOUND + }; + RouterFuture::from_response( + Response::builder() + .status(status_code) + .body(crate::body::empty()) + .unwrap(), + ) } } @@ -201,7 +220,11 @@ mod tests { ( RequestSpec::from_parts( Method::GET, - vec![PathSegment::Literal(String::from("a")), PathSegment::Label, PathSegment::Label], + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Label, + PathSegment::Label, + ], vec![], ), "A", diff --git a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs index e2820f92b..9109e269e 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs @@ -96,7 +96,11 @@ impl From<&PathSpec> for Regex { impl RequestSpec { pub fn new(method: http::Method, uri_spec: UriSpec) -> Self { let uri_path_regex = (&uri_spec.path_and_query.path_segments).into(); - RequestSpec { method, uri_spec, uri_path_regex } + RequestSpec { + method, + uri_spec, + uri_path_regex, + } } pub(super) fn matches(&self, req: &Request) -> Match { @@ -233,25 +237,38 @@ mod tests { #[test] fn repeated_query_keys_same_values_match() { - assert_eq!(Match::Yes, key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar"))); + assert_eq!( + Match::Yes, + key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar")) + ); } #[test] fn repeated_query_keys_distinct_values_does_not_match() { - assert_eq!(Match::No, key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz"))); + assert_eq!( + Match::No, + key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz")) + ); } fn ab_spec() -> RequestSpec { RequestSpec::from_parts( Method::GET, - vec![PathSegment::Literal(String::from("a")), PathSegment::Literal(String::from("b"))], + vec![ + PathSegment::Literal(String::from("a")), + PathSegment::Literal(String::from("b")), + ], vec![], ) } #[test] fn empty_segments_in_the_middle_dont_matter() { - let hits = vec![(Method::GET, "/a/b"), (Method::GET, "/a//b"), (Method::GET, "//////a//b")]; + let hits = vec![ + (Method::GET, "/a/b"), + (Method::GET, "/a//b"), + (Method::GET, "//////a//b"), + ]; for (method, uri) in &hits { assert_eq!(Match::Yes, ab_spec().matches(&req(method, uri))); } @@ -262,7 +279,11 @@ mod tests { // end of URIs _do_ matter. #[test] fn empty_segments_at_the_end_do_matter() { - let misses = vec![(Method::GET, "/a/b/"), (Method::GET, "/a/b//"), (Method::GET, "//a//b////")]; + let misses = vec![ + (Method::GET, "/a/b/"), + (Method::GET, "/a/b//"), + (Method::GET, "//a//b////"), + ]; for (method, uri) in &misses { assert_eq!(Match::No, ab_spec().matches(&req(method, uri))); } diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 8358feea1..2b6fe3227 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -57,13 +57,17 @@ impl Route { T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, { - Self { service: BoxCloneService::new(svc) } + Self { + service: BoxCloneService::new(svc), + } } } impl Clone for Route { fn clone(&self) -> Self { - Self { service: self.service.clone() } + Self { + service: self.service.clone(), + } } } -- GitLab