diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index 56952ff12de279229992f55bed9d68ddb99b5141..f0f156dfcf77e6cdc1caab2b5147aec65cf103b5 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -36,6 +36,8 @@ val allCodegenTests = listOf( CodegenTest("com.amazonaws.simple#SimpleService", "simple"), CodegenTest("aws.protocoltests.restjson#RestJson", "rest_json"), CodegenTest("aws.protocoltests.restjson.validation#RestJsonValidation", "rest_json_validation"), + CodegenTest("aws.protocoltests.json10#JsonRpc10", "json_rpc10"), + CodegenTest("aws.protocoltests.json#JsonProtocol", "json_rpc11"), CodegenTest("aws.protocoltests.misc#MiscService", "misc"), CodegenTest("com.amazonaws.ebs#Ebs", "ebs"), CodegenTest("com.amazonaws.s3#AmazonS3", "s3"), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index c959b03f0f382c21393369cb720c1fbb7770ff3b..898ec24e5207b81566919af6f8b80d25cfaa6e57 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -38,4 +38,7 @@ object ServerRuntimeType { fun ResponseRejection(runtimeConfig: RuntimeConfig) = RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection") + + fun Protocol(runtimeConfig: RuntimeConfig) = + RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols") } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt index 78fb063129b9b08558ba3825bf29c8b757419cbc..ba14ee7c922c75d88f8e7383bae13726b06b580c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationRegistryGenerator.kt @@ -5,6 +5,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustWriter @@ -28,7 +32,9 @@ class ServerOperationRegistryGenerator( private val httpBindingResolver: HttpBindingResolver, private val operations: List, ) { + private val protocol = codegenContext.protocol private val symbolProvider = codegenContext.symbolProvider + private val serviceName = codegenContext.serviceShape.toShapeId().name private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() } private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope = arrayOf( @@ -223,7 +229,7 @@ class ServerOperationRegistryGenerator( rustTemplate( """ $requestSpecs - #{Router}::from_box_clone_service_iter($towerServices) + #{Router}::${runtimeRouterConstructor()}($towerServices) """.trimIndent(), *codegenScope ) @@ -241,12 +247,42 @@ class ServerOperationRegistryGenerator( } /* - * Generate the `RequestSpec`s for an operation based on its HTTP-bound route. + * Finds the runtime function to construct a new `Router` based on the Protocol. */ - private fun OperationShape.requestSpec(): String { + private fun runtimeRouterConstructor(): String = + when (protocol) { + RestJson1Trait.ID -> "new_rest_json_router" + RestXmlTrait.ID -> "new_rest_xml_router" + AwsJson1_0Trait.ID -> "new_aws_json_10_router" + AwsJson1_1Trait.ID -> "new_aws_json_11_router" + else -> TODO("Protocol $protocol not supported yet") + } + + /* + * Returns the `RequestSpec`s for an operation based on its HTTP-bound route. + */ + private fun OperationShape.requestSpec(): String = + when (protocol) { + RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec() + AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName() + else -> TODO("Protocol $protocol not supported yet") + } + + /* + * Returns an AwsJson specific runtime `RequestSpec`. + */ + private fun OperationShape.awsJsonOperationName(): String { + val operationName = symbolProvider.toSymbol(this).name + // TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait + return """String::from("$serviceName.$operationName")""" + } + + /* + * Generates a REST (RestJson1, RestXml) specific runtime `RequestSpec`. + */ + private fun OperationShape.restRequestSpec(): String { val httpTrait = httpBindingResolver.httpTrait(this) val namespace = ServerRuntimeType.RequestSpecModule(runtimeConfig).fullyQualifiedName() - // TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait. val pathSegments = httpTrait.uri.segments.map { "$namespace::PathSegment::" + @@ -268,7 +304,7 @@ class ServerOperationRegistryGenerator( $namespace::PathSpec::from_vector_unchecked(vec![${pathSegments.joinToString()}]), $namespace::QuerySpec::from_vector_unchecked(vec![${querySegments.joinToString()}]) ) - ) + ), ) """.trimIndent() } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 6a684898ae39033147128478347d68a35288acf2..96faa29fcfec68a1449c8eec8cded8012ebb6d73 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -437,8 +437,8 @@ class ServerProtocolTestGenerator( else -> { rustWriter.rustTemplate( """ - #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); - """, + #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`"); + """, *codegenScope ) } @@ -798,6 +798,18 @@ class ServerProtocolTestGenerator( FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostAccelerateAddressing", TestType.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3VirtualHostDualstackAccelerateAddressing", TestType.Request), FailingTest("com.amazonaws.s3#AmazonS3", "S3OperationAddressingPreferred", TestType.Request), + + // AwsJson1.0 failing tests. + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTraitWithHostLabel", TestType.Request), + FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request), + + // AwsJson1.1 failing tests. + FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request), + FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request), + FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response), + FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response), + FailingTest("aws.protocoltests.json#JsonProtocol", "parses_the_request_id_from_the_response", TestType.Response), + ) private val RunOnly: Set? = null diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt new file mode 100644 index 0000000000000000000000000000000000000000..67209d17f405665822474e4c43dea2ee1df26782 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.server.smithy.protocols + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.escape +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJson +import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver +import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory +import software.amazon.smithy.rust.codegen.smithy.protocols.awsJsonFieldName +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonCustomization +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSection +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator +import software.amazon.smithy.rust.codegen.util.hasTrait + +/* + * AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator] + * with AwsJson specific configurations. + */ +class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { + override fun protocol(codegenContext: CodegenContext): Protocol = ServerAwsJson(codegenContext, version) + + override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator = + ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) + + override fun transformModel(model: Model): Model = model + + override fun support(): ProtocolSupport { + return ProtocolSupport( + /* Client support */ + requestSerialization = false, + requestBodySerialization = false, + responseDeserialization = false, + errorDeserialization = false, + /* Server support */ + requestDeserialization = true, + requestBodyDeserialization = true, + responseSerialization = true, + errorSerialization = true + ) + } +} + +/** + * AwsJson requires errors to be serialized with an additional "__type" field. This + * customization writes the right field depending on the version of the AwsJson protocol. + */ +class ServerAwsJsonError(private val awsJsonVersion: AwsJsonVersion) : JsonCustomization() { + override fun section(section: JsonSection): Writable = when (section) { + is JsonSection.ServerError -> writable { + if (section.structureShape.hasTrait()) { + val typeId = when (awsJsonVersion) { + // AwsJson 1.0 wants the whole shape ID (namespace#Shape). + // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#operation-error-serialization + AwsJsonVersion.Json10 -> section.structureShape.id.toString() + // AwsJson 1.1 wants only the shape name (Shape). + // https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#operation-error-serialization + AwsJsonVersion.Json11 -> section.structureShape.id.name.toString() + } + rust("""${section.jsonObject}.key("__type").string("${escape(typeId)}");""") + } + } + } +} + +/** + * AwsJson requires errors to be serialized with an additional "__type" field. This class + * customizes [JsonSerializerGenerator] to add this functionality. + */ +class ServerAwsJsonSerializerGenerator( + private val codegenContext: CodegenContext, + private val httpBindingResolver: HttpBindingResolver, + private val awsJsonVersion: AwsJsonVersion, + private val jsonSerializerGenerator: JsonSerializerGenerator = + JsonSerializerGenerator(codegenContext, httpBindingResolver, ::awsJsonFieldName, customizations = listOf(ServerAwsJsonError(awsJsonVersion))) +) : StructuredDataSerializerGenerator by jsonSerializerGenerator + +class ServerAwsJson( + private val codegenContext: CodegenContext, + private val awsJsonVersion: AwsJsonVersion +) : AwsJson(codegenContext, awsJsonVersion) { + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + ServerAwsJsonSerializerGenerator(codegenContext, httpBindingResolver, awsJsonVersion) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 9b8b9b29648b7be37340be68b4e9dc72e8efe97a..745ebbcc275bb81336f2069f20000af2517d177b 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.codegen.core.Symbol @@ -1064,10 +1066,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private fun getContentTypeCheck(): String { when (codegenContext.protocol) { RestJson1Trait.ID -> { - return "check_json_content_type" + return "check_rest_json_1_content_type" } RestXmlTrait.ID -> { - return "check_xml_content_type" + return "check_rest_xml_content_type" + } + AwsJson1_0Trait.ID -> { + return "check_aws_json_10_content_type" + } + AwsJson1_1Trait.ID -> { + return "check_aws_json_11_content_type" } else -> { TODO("Protocol ${codegenContext.protocol} not supported yet") @@ -1086,7 +1094,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( return ServerRuntimeType.RequestRejection(runtimeConfig) } when (codegenContext.protocol) { - RestJson1Trait.ID -> { + RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> { return CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error") } RestXmlTrait.ID -> { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index 94bb5407c0dca10ee94e153b7a8d50b90ba9c654..97069b27d683896aeae2e93870d0b609e0a67078 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.codegen.core.CodegenException @@ -14,6 +16,7 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap @@ -39,6 +42,8 @@ class ServerProtocolLoader(private val supportedProtocols: ProtocolMap) { val DefaultProtocols = mapOf( RestJson1Trait.ID to ServerRestJsonFactory(), RestXmlTrait.ID to ServerRestXmlFactory(), + AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10), + AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXml.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXml.kt index 80c546c4ae5b2020b586e970baae04fb766d96fc..1fc094fe189a8dbf3d87c9079420af5d7481f3ee 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXml.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXml.kt @@ -16,9 +16,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml * RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] * with RestXml specific configurations. */ -class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) : - ProtocolGeneratorFactory { - override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext) +class ServerRestXmlFactory : ProtocolGeneratorFactory { + override fun protocol(codegenContext: CodegenContext): Protocol = RestXml(codegenContext) override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator(codegenContext, RestXml(codegenContext)) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt index 8ad3813b604dcdfacb6d2e2bb6c5fed56d9269d0..41b837a3ab3fbdd614eb09c0fa9b50281e840ea8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt @@ -128,7 +128,7 @@ class AwsJsonSerializerGenerator( } } -class AwsJson( +open class AwsJson( private val codegenContext: CodegenContext, awsJsonVersion: AwsJsonVersion ) : Protocol { @@ -183,6 +183,4 @@ class AwsJson( } } -private fun awsJsonFieldName(member: MemberShape): String { - return member.memberName -} +fun awsJsonFieldName(member: MemberShape): String = member.memberName diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt index 2272eff87ea50816a07a23163c72eea5431b1998..e5ef441119e86fd4e8cee1144f86af965fbc6bd2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.customize.NamedSectionGenerator +import software.amazon.smithy.rust.codegen.smithy.customize.Section import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.smithy.generators.serializationError @@ -49,11 +51,25 @@ import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape +/** + * Class describing a JSON section that can be used in a customization. + */ +sealed class JsonSection(name: String) : Section(name) { + /** Mutate the server error object prior to finalization. Eg: this can be used to inject `__type` to record the error type. */ + data class ServerError(val structureShape: StructureShape, val jsonObject: String) : JsonSection("ServerError") +} + +/** + * JSON customization. + */ +typealias JsonCustomization = NamedSectionGenerator + class JsonSerializerGenerator( codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, + private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { private data class Context( /** Expression that retrieves a JsonValueWriter from either a JsonObjectWriter or JsonArrayWriter */ @@ -155,7 +171,7 @@ class JsonSerializerGenerator( private fun serverStructureSerializer( fnName: String, structureShape: StructureShape, - includedMembers: List + includedMembers: List, ): RuntimeType { return RuntimeType.forInlineFun(fnName, operationSerModule) { it.rustBlockTemplate( @@ -166,6 +182,7 @@ class JsonSerializerGenerator( rust("let mut out = String::new();") rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) serializeStructure(StructContext("object", "value", structureShape), includedMembers) + customizations.forEach { it.section(JsonSection.ServerError(structureShape, "object"))(this) } rust("object.finish();") rustTemplate("Ok(out)", *codegenScope) } diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 49f9bff63bc41decd7260ac14acefa97a7706d25..2a7f1ae82ca8a1db298b62067a6c6d984e9e5195 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -29,6 +29,7 @@ http-body = "0.4" hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] } mime = "0.3" nom = "7" +paste = "1" pin-project-lite = "0.2" regex = "1.0" serde_urlencoded = "0.7" diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index a1113977a6e76a778dc90e49d5a3433ded9e3a67..473b78ccb7ad3cdba7d2c4fba5f9e51cdf641011 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -6,51 +6,206 @@ //! Protocol helpers. use crate::rejection::RequestRejection; use axum_core::extract::RequestParts; +use paste::paste; -#[derive(Debug)] +/// Supported protocols. +#[derive(Debug, Clone, Copy)] pub enum Protocol { RestJson1, RestXml, + AwsJson10, + AwsJson11, } -/// Validate that the request had the standard JSON content-type header. -pub fn check_json_content_type(req: &RequestParts) -> Result<(), RequestRejection> { - let mime = req - .headers() - .ok_or(RequestRejection::MissingJsonContentType)? - .get(http::header::CONTENT_TYPE) - .ok_or(RequestRejection::MissingJsonContentType)? - .to_str() - .map_err(|_| RequestRejection::MissingJsonContentType)? - .parse::() - .map_err(|_| RequestRejection::MimeParse)?; - - if mime.type_() == "application" - && (mime.subtype() == "json" || mime.suffix().filter(|name| *name == "json").is_some()) - { - Ok(()) - } else { - Err(RequestRejection::MissingJsonContentType) - } +/// Implement the content-type header validation for a request. +macro_rules! impl_content_type_validation { + ($name:literal, $type: literal, $subtype:literal, $rejection:path) => { + paste! { + #[doc = concat!("Validates that the request has the standard `", $type, "/", $subtype, "` content-type header.")] + pub fn [](req: &RequestParts) -> Result<(), RequestRejection> { + let mime = req + .headers() + .ok_or($rejection)? + .get(http::header::CONTENT_TYPE) + .ok_or($rejection)? + .to_str() + .map_err(|_| $rejection)? + .parse::() + .map_err(|_| RequestRejection::MimeParse)?; + if mime.type_() == $type && mime.subtype() == $subtype { + Ok(()) + } else { + Err($rejection) + } + } + } + }; } -/// Validate that the request had the standard XML content-type header. -pub fn check_xml_content_type(req: &RequestParts) -> Result<(), RequestRejection> { - let mime = req - .headers() - .ok_or(RequestRejection::MissingXmlContentType)? - .get(http::header::CONTENT_TYPE) - .ok_or(RequestRejection::MissingXmlContentType)? - .to_str() - .map_err(|_| RequestRejection::MissingXmlContentType)? - .parse::() - .map_err(|_| RequestRejection::MimeParse)?; - - if mime.type_() == "application" - && (mime.subtype() == "xml" || mime.suffix().filter(|name| *name == "xml").is_some()) - { - Ok(()) - } else { - Err(RequestRejection::MissingXmlContentType) +impl_content_type_validation!( + "rest_json_1", + "application", + "json", + RequestRejection::MissingRestJson1ContentType +); + +impl_content_type_validation!( + "rest_xml", + "application", + "xml", + RequestRejection::MissingRestXmlContentType +); + +impl_content_type_validation!( + "aws_json_10", + "application", + "x-amz-json-1.0", + RequestRejection::MissingAwsJson10ContentType +); + +impl_content_type_validation!( + "aws_json_11", + "application", + "x-amz-json-1.1", + RequestRejection::MissingAwsJson11ContentType +); + +#[cfg(test)] +mod tests { + use super::*; + use http::Request; + + fn req(content_type: &str) -> RequestParts<&str> { + let request = Request::builder() + .header("content-type", content_type) + .body("") + .unwrap(); + RequestParts::new(request) + } + + /// This macro validates the rejection type since we cannot implement `PartialEq` + /// for `RequestRejection` as it is based on the crate error type, which uses + /// `axum_core::BoxError`. + macro_rules! validate_rejection_type { + ($result:expr, $rejection:path) => { + match $result { + Ok(()) => panic!("Content-type validation is expected to fail"), + Err(e) => match e { + $rejection => {} + _ => panic!("Error {} should be {}", e.to_string(), stringify!($rejection)), + }, + } + }; + } + + #[test] + fn validate_rest_json_1_content_type() { + // Check valid content-type header. + let request = req("application/json"); + assert!(check_rest_json_1_content_type(&request).is_ok()); + + // Check invalid content-type header. + let invalid = vec![ + req("application/ajson"), + req("application/json1"), + req("applicatio/json"), + req("application/xml"), + req("text/xml"), + req("application/x-amz-json-1.0"), + req("application/x-amz-json-1.1"), + RequestParts::new(Request::builder().body("").unwrap()), + ]; + for request in &invalid { + validate_rejection_type!( + check_rest_json_1_content_type(request), + RequestRejection::MissingRestJson1ContentType + ); + } + + // Check request with not parsable content-type header. + validate_rejection_type!(check_rest_json_1_content_type(&req("123")), RequestRejection::MimeParse); + } + + #[test] + fn validate_rest_xml_content_type() { + // Check valid content-type header. + let request = req("application/xml"); + assert!(check_rest_xml_content_type(&request).is_ok()); + + // Check invalid content-type header. + let invalid = vec![ + req("application/axml"), + req("application/xml1"), + req("applicatio/xml"), + req("text/xml"), + req("application/x-amz-json-1.0"), + req("application/x-amz-json-1.1"), + RequestParts::new(Request::builder().body("").unwrap()), + ]; + for request in &invalid { + validate_rejection_type!( + check_rest_xml_content_type(request), + RequestRejection::MissingRestXmlContentType + ); + } + + // Check request with not parsable content-type header. + validate_rejection_type!(check_rest_xml_content_type(&req("123")), RequestRejection::MimeParse); + } + + #[test] + fn validate_aws_json_10_content_type() { + // Check valid content-type header. + let request = req("application/x-amz-json-1.0"); + assert!(check_aws_json_10_content_type(&request).is_ok()); + + // Check invalid content-type header. + let invalid = vec![ + req("application/x-amz-json-1."), + req("application/-amz-json-1.0"), + req("application/xml"), + req("application/json"), + req("applicatio/x-amz-json-1.0"), + req("text/xml"), + req("application/x-amz-json-1.1"), + RequestParts::new(Request::builder().body("").unwrap()), + ]; + for request in &invalid { + validate_rejection_type!( + check_aws_json_10_content_type(request), + RequestRejection::MissingAwsJson10ContentType + ); + } + + // Check request with not parsable content-type header. + validate_rejection_type!(check_aws_json_10_content_type(&req("123")), RequestRejection::MimeParse); + } + + #[test] + fn validate_aws_json_11_content_type() { + // Check valid content-type header. + let request = req("application/x-amz-json-1.1"); + assert!(check_aws_json_11_content_type(&request).is_ok()); + + // Check invalid content-type header. + let invalid = vec![ + req("application/x-amz-json-1."), + req("application/-amz-json-1.1"), + req("application/xml"), + req("application/json"), + req("applicatio/x-amz-json-1.1"), + req("text/xml"), + req("application/x-amz-json-1.0"), + RequestParts::new(Request::builder().body("").unwrap()), + ]; + for request in &invalid { + validate_rejection_type!( + check_aws_json_11_content_type(request), + RequestRejection::MissingAwsJson11ContentType + ); + } + + // Check request with not parsable content-type header. + validate_rejection_type!(check_aws_json_11_content_type(&req("123")), RequestRejection::MimeParse); } } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 9f96f32b0110b44c4518aec11f30dc4f15baaf46..7de0872b71530f0f4634f6d58a9f66d77c806ec2 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -134,8 +134,10 @@ pub enum RequestRejection { HttpBody(crate::Error), // These are used when checking the `Content-Type` header. - MissingJsonContentType, - MissingXmlContentType, + MissingRestJson1ContentType, + MissingAwsJson10ContentType, + MissingAwsJson11ContentType, + MissingRestXmlContentType, MimeParse, /// Used when failing to deserialize the HTTP body's bytes into a JSON document conforming to diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 335a951a393d93e1316d6a457b0d80a5ff6be782..8ec4562fa44febf7716fe53c46aa380b10e7cde0 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -7,10 +7,15 @@ //! //! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html -use self::{future::RouterFuture, request_spec::RequestSpec}; +use self::future::RouterFuture; +use self::request_spec::RequestSpec; use crate::body::{boxed, Body, BoxBody, HttpBody}; +use crate::protocols::Protocol; +use crate::runtime_error::{RuntimeError, RuntimeErrorKind}; use crate::BoxError; +use axum_core::response::IntoResponse; use http::{Request, Response, StatusCode}; +use std::collections::HashMap; use std::{ convert::Infallible, task::{Context, Poll}, @@ -31,34 +36,60 @@ mod route; pub use self::{into_make_service::IntoMakeService, route::Route}; /// The router is a [`tower::Service`] that routes incoming requests to other `Service`s -/// based on the request's URI and HTTP method, adhering to the [Smithy specification]. +/// based on the request's URI and HTTP method or on some specific header setting the target operation. +/// The former is adhering to the [Smithy specification], while the latter is adhering to +/// the [AwsJson specification]. +/// +/// The router is also [Protocol] aware and currently supports REST based protocols like [restJson1] or [restXml] +/// and RPC based protocols like [awsJson1.0] or [awsJson1.1]. /// It currently does not support Smithy's [endpoint trait]. /// /// You should not **instantiate** this router directly; it will be created for you from the /// code generated from your Smithy model by `smithy-rs`. /// /// [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html +/// [AwsJson specification]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#protocol-behaviors +/// [Protocol]: https://awslabs.github.io/smithy/1.0/spec/aws/index.html#aws-protocols +/// [restJson1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restjson1-protocol.html +/// [restXml]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html +/// [awsJson1.0]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html +/// [awsJson1.1]: https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html /// [endpoint trait]: https://awslabs.github.io/smithy/1.0/spec/core/endpoint-traits.html#endpoint-trait #[derive(Debug)] pub struct Router { - routes: Vec<(Route, RequestSpec)>, + routes: Routes, } -impl Clone for Router { - fn clone(&self) -> Self { - Self { - routes: self.routes.clone(), - } - } +/// Protocol-aware routes types. +/// +/// RestJson1 and RestXml routes are stored in a `Vec` because there can be multiple matches on the +/// request URI and we thus need to iterate the whole list and use a ranking mechanism to choose. +/// +/// AwsJson 1.0 and 1.1 routes can be stored in a `HashMap` since the requested operation can be +/// directly found in the `X-Amz-Target` HTTP header. +#[derive(Debug)] +enum Routes { + RestXml(Vec<(Route, RequestSpec)>), + RestJson1(Vec<(Route, RequestSpec)>), + AwsJson10(HashMap>), + AwsJson11(HashMap>), } -impl Default for Router -where - B: Send + 'static, -{ - fn default() -> Self { - Self { - routes: Default::default(), +impl Clone for Router { + fn clone(&self) -> Self { + match &self.routes { + Routes::RestJson1(routes) => Router { + routes: Routes::RestJson1(routes.clone()), + }, + Routes::RestXml(routes) => Router { + routes: Routes::RestXml(routes.clone()), + }, + Routes::AwsJson10(routes) => Router { + routes: Routes::AwsJson10(routes.clone()), + }, + Routes::AwsJson11(routes) => Router { + routes: Routes::AwsJson11(routes.clone()), + }, } } } @@ -67,32 +98,29 @@ impl Router where B: Send + 'static, { - /// Create a new `Router` from a vector of pairs of request specs and services. - /// - /// If the vector is empty the router will respond `404 Not Found` to all requests. - #[doc(hidden)] - pub fn from_box_clone_service_iter(routes: T) -> Self - where - T: IntoIterator< - Item = ( - tower::util::BoxCloneService, Response, Infallible>, - RequestSpec, - ), - >, - { - let mut routes: Vec<(Route, RequestSpec)> = routes - .into_iter() - .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec)) - .collect(); - - // Sort them once by specifity, with the more specific routes sorted before the less - // specific ones, so that when routing a request we can simply iterate through the routes - // and pick the first one that matches. - routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank())); - - Self { routes } + /// Return the correct, protocol-specific "Not Found" response for an unknown operation. + fn unknown_operation(&self) -> RouterFuture { + let protocol = match &self.routes { + Routes::RestJson1(_) => Protocol::RestJson1, + Routes::RestXml(_) => Protocol::RestXml, + Routes::AwsJson10(_) => Protocol::AwsJson10, + Routes::AwsJson11(_) => Protocol::AwsJson11, + }; + let error = RuntimeError { + protocol, + kind: RuntimeErrorKind::UnknownOperation, + }; + RouterFuture::from_response(error.into_response()) } + /// Return the HTTP error response for non allowed method. + fn method_not_allowed(&self) -> RouterFuture { + RouterFuture::from_response({ + let mut res = Response::new(crate::body::empty()); + *res.status_mut() = StatusCode::METHOD_NOT_ALLOWED; + res + }) + } /// Convert this router into a [`MakeService`], that is a [`Service`] whose /// response is another service. /// @@ -124,12 +152,146 @@ where .layer_fn(Route::new) .layer(MapResponseBodyLayer::new(boxed)) .layer(layer); - let routes = self - .routes + match self.routes { + Routes::RestJson1(routes) => { + let routes = routes + .into_iter() + .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)) + .collect(); + Router { + routes: Routes::RestJson1(routes), + } + } + Routes::RestXml(routes) => { + let routes = routes + .into_iter() + .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)) + .collect(); + Router { + routes: Routes::RestXml(routes), + } + } + Routes::AwsJson10(routes) => { + let routes = routes + .into_iter() + .map(|(operation, route)| (operation, Layer::layer(&layer, route))) + .collect(); + Router { + routes: Routes::AwsJson10(routes), + } + } + Routes::AwsJson11(routes) => { + let routes = routes + .into_iter() + .map(|(operation, route)| (operation, Layer::layer(&layer, route))) + .collect(); + Router { + routes: Routes::AwsJson11(routes), + } + } + } + } + + /// Create a new RestJson1 `Router` from an iterator over pairs of [`RequestSpec`]s and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_rest_json_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + RequestSpec, + ), + >, + { + let mut routes: Vec<(Route, RequestSpec)> = routes + .into_iter() + .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec)) + .collect(); + + // Sort them once by specifity, with the more specific routes sorted before the less + // specific ones, so that when routing a request we can simply iterate through the routes + // and pick the first one that matches. + routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank())); + + Self { + routes: Routes::RestJson1(routes), + } + } + + /// Create a new RestXml `Router` from an iterator over pairs of [`RequestSpec`]s and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_rest_xml_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + RequestSpec, + ), + >, + { + let mut routes: Vec<(Route, RequestSpec)> = routes + .into_iter() + .map(|(svc, request_spec)| (Route::from_box_clone_service(svc), request_spec)) + .collect(); + + // Sort them once by specifity, with the more specific routes sorted before the less + // specific ones, so that when routing a request we can simply iterate through the routes + // and pick the first one that matches. + routes.sort_by_key(|(_route, request_spec)| std::cmp::Reverse(request_spec.rank())); + + Self { + routes: Routes::RestXml(routes), + } + } + + /// Create a new AwsJson 1.0 `Router` from an iterator over pairs of operation names and services. + /// + /// If the iterator is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_aws_json_10_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + String, + ), + >, + { + let routes = routes + .into_iter() + .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) + .collect(); + + Self { + routes: Routes::AwsJson10(routes), + } + } + + /// Create a new AwsJson 1.1 `Router` from a vector of pairs of operations and services. + /// + /// If the vector is empty the router will respond `404 Not Found` to all requests. + #[doc(hidden)] + pub fn new_aws_json_11_router(routes: T) -> Self + where + T: IntoIterator< + Item = ( + tower::util::BoxCloneService, Response, Infallible>, + String, + ), + >, + { + let routes = routes .into_iter() - .map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)) + .map(|(svc, operation)| (operation, Route::from_box_clone_service(svc))) .collect(); - Router { routes } + + Self { + routes: Routes::AwsJson11(routes), + } } } @@ -148,44 +310,84 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - let mut method_not_allowed = false; + match &self.routes { + // REST routes. + Routes::RestJson1(routes) | Routes::RestXml(routes) => { + let mut method_not_allowed = false; + + // Loop through all the routes and validate if any of them matches. Routes are already ranked. + for (route, request_spec) in routes { + match request_spec.matches(&req) { + request_spec::Match::Yes => { + return RouterFuture::from_oneshot(route.clone().oneshot(req)); + } + request_spec::Match::MethodNotAllowed => method_not_allowed = true, + // Continue looping to see if another route matches. + request_spec::Match::No => continue, + } + } - for (route, request_spec) in &self.routes { - match request_spec.matches(&req) { - request_spec::Match::Yes => { - return RouterFuture::from_oneshot(route.clone().oneshot(req)); + if method_not_allowed { + // The HTTP method is not correct. + self.method_not_allowed() + } else { + // In any other case return the `RuntimeError::UnknownOperation`. + self.unknown_operation() + } + } + // AwsJson routes. + Routes::AwsJson10(routes) | Routes::AwsJson11(routes) => { + if req.uri() == "/" { + // Check the request method for POST. + if req.method() == http::Method::POST { + // Find the `x-amz-target` header. + if let Some(target) = req.headers().get("x-amz-target") { + if let Ok(target) = target.to_str() { + // Lookup in the `HashMap` for a route for the target. + let route = routes.get(target); + if let Some(route) = route { + return RouterFuture::from_oneshot(route.clone().oneshot(req)); + } + } + } + } else { + // The HTTP method is not POST. + return self.method_not_allowed(); + } } - request_spec::Match::MethodNotAllowed => method_not_allowed = true, - // Continue looping to see if another route matches. - request_spec::Match::No => continue, + // In any other case return the `RuntimeError::UnknownOperation`. + self.unknown_operation() } } - - let status_code = if method_not_allowed { - StatusCode::METHOD_NOT_ALLOWED - } else { - StatusCode::NOT_FOUND - }; - RouterFuture::from_response( - Response::builder() - .status(status_code) - .body(crate::body::empty()) - .unwrap(), - ) } } #[cfg(test)] -mod tests { +mod rest_tests { use super::*; use crate::{body::boxed, routing::request_spec::*}; use futures_util::Future; - use http::Method; + use http::{HeaderMap, Method}; use std::pin::Pin; /// Helper function to build a `Request`. Used in other test modules. - pub fn req(method: &Method, uri: &str) -> Request<()> { - Request::builder().method(method).uri(uri).body(()).unwrap() + pub fn req(method: &Method, uri: &str, headers: Option) -> Request<()> { + let mut r = Request::builder().method(method).uri(uri).body(()).unwrap(); + if let Some(headers) = headers { + *r.headers_mut() = headers + } + r + } + + // Returns a `Response`'s body as a `String`, without consuming the response. + pub async fn get_body_as_string(res: &mut Response) -> String + where + B: http_body::Body + std::marker::Unpin, + B::Error: std::fmt::Debug, + { + let body_mut = res.body_mut(); + let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap(); + String::from(std::str::from_utf8(&body_bytes).unwrap()) } /// A service that returns its name and the request's URI in the response body. @@ -210,17 +412,6 @@ mod tests { } } - // Returns a `Response`'s body as a `String`, without consuming the response. - async fn get_body_as_str(res: &mut Response) -> String - where - B: http_body::Body + std::marker::Unpin, - B::Error: std::fmt::Debug, - { - let body_mut = res.body_mut(); - let body_bytes = hyper::body::to_bytes(body_mut).await.unwrap(); - String::from(std::str::from_utf8(&body_bytes).unwrap()) - } - // This test is a rewrite of `mux.spec.ts`. // https://github.com/awslabs/smithy-typescript/blob/fbf97a9bf4c1d8cf7f285ea7c24e1f0ef280142a/smithy-typescript-ssdk-libs/server-common/src/httpbinding/mux.spec.ts #[tokio::test] @@ -271,55 +462,64 @@ mod tests { ), ]; - let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| { + // Test both RestJson1 and RestXml routers. + let router_json = Router::new_rest_json_router(request_specs.clone().into_iter().map(|(spec, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), + spec, + ) + })); + let router_xml = Router::new_rest_xml_router(request_specs.into_iter().map(|(spec, svc_name)| { ( tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), spec, ) })); - let hits = vec![ - ("A", Method::GET, "/a/b/c"), - ("MiddleGreedy", Method::GET, "/mg/a/z"), - ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), - ("Delete", Method::DELETE, "/?foo=bar&baz"), - ("Delete", Method::DELETE, "/?foo=bar&baz=&"), - ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), - ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), - ]; - for (svc_name, method, uri) in &hits { - let mut res = router.call(req(method, uri)).await.unwrap(); - let actual_body = get_body_as_str(&mut res).await; - - assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); - } + for mut router in [router_json, router_xml] { + let hits = vec![ + ("A", Method::GET, "/a/b/c"), + ("MiddleGreedy", Method::GET, "/mg/a/z"), + ("MiddleGreedy", Method::GET, "/mg/a/b/c/d/z?abc=def"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux"), + ("Delete", Method::DELETE, "/?foo=bar&baz"), + ("Delete", Method::DELETE, "/?foo=bar&baz=&"), + ("Delete", Method::DELETE, "/?foo=bar&baz=quux&baz=grault"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=bar"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo"), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo="), + ("QueryKeyOnly", Method::POST, "/query_key_only?foo=&"), + ]; + for (svc_name, method, uri) in &hits { + let mut res = router.call(req(method, uri, None)).await.unwrap(); + let actual_body = get_body_as_string(&mut res).await; + + assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); + } - for (_, _, uri) in hits { - let res = router.call(req(&Method::PATCH, uri)).await.unwrap(); - assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status()); - } + for (_, _, uri) in hits { + let res = router.call(req(&Method::PATCH, uri, None)).await.unwrap(); + assert_eq!(StatusCode::METHOD_NOT_ALLOWED, res.status()); + } - let misses = vec![ - (Method::GET, "/a"), - (Method::GET, "/a/b"), - (Method::GET, "/mg"), - (Method::GET, "/mg/q"), - (Method::GET, "/mg/z"), - (Method::GET, "/mg/a/b/z/c"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?foo=bar"), - (Method::DELETE, "/?baz=quux"), - (Method::POST, "/query_key_only?baz=quux"), - (Method::GET, "/"), - (Method::POST, "/"), - ]; - for (method, miss) in misses { - let res = router.call(req(&method, miss)).await.unwrap(); - assert_eq!(StatusCode::NOT_FOUND, res.status()); + let misses = vec![ + (Method::GET, "/a"), + (Method::GET, "/a/b"), + (Method::GET, "/mg"), + (Method::GET, "/mg/q"), + (Method::GET, "/mg/z"), + (Method::GET, "/mg/a/b/z/c"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?foo=bar"), + (Method::DELETE, "/?baz=quux"), + (Method::POST, "/query_key_only?baz=quux"), + (Method::GET, "/"), + (Method::POST, "/"), + ]; + for (method, miss) in misses { + let res = router.call(req(&method, miss, None)).await.unwrap(); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + } } } @@ -364,7 +564,7 @@ mod tests { ), ]; - let mut router = Router::from_box_clone_service_iter(request_specs.into_iter().map(|(spec, svc_name)| { + let mut router = Router::new_rest_json_router(request_specs.into_iter().map(|(spec, svc_name)| { ( tower::util::BoxCloneService::new(NamedEchoUriService(String::from(svc_name))), spec, @@ -378,10 +578,96 @@ mod tests { ("B2", Method::GET, "/b/foo?q=baz"), ]; for (svc_name, method, uri) in &hits { - let mut res = router.call(req(method, uri)).await.unwrap(); - let actual_body = get_body_as_str(&mut res).await; + let mut res = router.call(req(method, uri, None)).await.unwrap(); + let actual_body = get_body_as_string(&mut res).await; assert_eq!(format!("{} :: {}", svc_name, uri), actual_body); } } } + +#[cfg(test)] +mod awsjson_tests { + use super::rest_tests::{get_body_as_string, req}; + use super::*; + use crate::body::boxed; + use futures_util::Future; + use http::{HeaderMap, HeaderValue, Method}; + use pretty_assertions::assert_eq; + use std::pin::Pin; + + /// A service that returns its name and the request's URI in the response body. + #[derive(Clone)] + struct NamedEchoOperationService(String); + + impl Service> for NamedEchoOperationService { + type Response = Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + let target = req + .headers() + .get("x-amz-target") + .map(|x| x.to_str().unwrap()) + .unwrap_or("unknown"); + let body = boxed(Body::from(format!("{} :: {}", self.0, target))); + let fut = async { Ok(Response::builder().status(&http::StatusCode::OK).body(body).unwrap()) }; + Box::pin(fut) + } + } + + #[tokio::test] + async fn simple_routing() { + let routes = vec![("Service.Operation", "A")]; + let router_json10 = Router::new_aws_json_10_router(routes.clone().into_iter().map(|(operation, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), + operation.to_string(), + ) + })); + let router_json11 = Router::new_aws_json_11_router(routes.into_iter().map(|(operation, svc_name)| { + ( + tower::util::BoxCloneService::new(NamedEchoOperationService(String::from(svc_name))), + operation.to_string(), + ) + })); + + for mut router in [router_json10, router_json11] { + let mut headers = HeaderMap::new(); + headers.insert("x-amz-target", HeaderValue::from_static("Service.Operation")); + + // Valid request, should return a valid body. + let mut res = router + .call(req(&Method::POST, "/", Some(headers.clone()))) + .await + .unwrap(); + let actual_body = get_body_as_string(&mut res).await; + assert_eq!(format!("{} :: {}", "A", "Service.Operation"), actual_body); + + // No headers, should return NOT_FOUND. + let res = router.call(req(&Method::POST, "/", None)).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // Wrong HTTP method, should return METHOD_NOT_ALLOWED. + let res = router + .call(req(&Method::GET, "/", Some(headers.clone()))) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + // Wrong URI, should return NOT_FOUND. + let res = router + .call(req(&Method::POST, "/something", Some(headers))) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs index da22366bd171934e4425c953bf3030a788e7a7f1..adfb02ffc09106efd3b30c9bbbe70d706790337d 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/request_spec.rs @@ -243,7 +243,7 @@ impl RequestSpec { #[cfg(test)] mod tests { - use super::super::tests::req; + use super::super::rest_tests::req; use super::*; use http::Method; @@ -295,7 +295,7 @@ mod tests { (Method::GET, "/mg/a/z/z/z"), ]; for (method, uri) in &hits { - assert_eq!(Match::Yes, spec.matches(&req(method, uri))); + assert_eq!(Match::Yes, spec.matches(&req(method, uri, None))); } } @@ -309,7 +309,7 @@ mod tests { (Method::DELETE, "/?foo&foo"), ]; for (method, uri) in &hits { - assert_eq!(Match::Yes, spec.matches(&req(method, uri))); + assert_eq!(Match::Yes, spec.matches(&req(method, uri, None))); } } @@ -325,7 +325,7 @@ mod tests { fn repeated_query_keys_same_values_match() { assert_eq!( Match::Yes, - key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar")) + key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=bar", None)) ); } @@ -333,7 +333,7 @@ mod tests { fn repeated_query_keys_distinct_values_does_not_match() { assert_eq!( Match::No, - key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz")) + key_value_spec().matches(&req(&Method::DELETE, "/?foo=bar&foo=baz", None)) ); } @@ -354,11 +354,11 @@ mod tests { #[test] fn empty_segments_in_the_middle_do_matter() { - assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b"))); + assert_eq!(Match::Yes, ab_spec().matches(&req(&Method::GET, "/a/b", None))); let misses = vec![(Method::GET, "/a//b"), (Method::GET, "//////a//b")]; for (method, uri) in &misses { - assert_eq!(Match::No, ab_spec().matches(&req(method, uri))); + assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None))); } } @@ -379,10 +379,10 @@ mod tests { (Method::GET, "/a//b"), // Label is bound to `""`. ]; for (method, uri) in &hits { - assert_eq!(Match::Yes, label_spec.matches(&req(method, uri))); + assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None))); } - assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b"))); + assert_eq!(Match::No, label_spec.matches(&req(&Method::GET, "/a///b", None))); } #[test] @@ -403,7 +403,7 @@ mod tests { (Method::GET, "/a///a//b///suffix"), ]; for (method, uri) in &hits { - assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri))); + assert_eq!(Match::Yes, greedy_label_spec.matches(&req(method, uri, None))); } } @@ -418,7 +418,7 @@ mod tests { (Method::GET, "//a//b////"), ]; for (method, uri) in &misses { - assert_eq!(Match::No, ab_spec().matches(&req(method, uri))); + assert_eq!(Match::No, ab_spec().matches(&req(method, uri, None))); } } @@ -432,13 +432,13 @@ mod tests { let misses = vec![(Method::GET, "/a"), (Method::GET, "/a//"), (Method::GET, "/a///")]; for (method, uri) in &misses { - assert_eq!(Match::No, label_spec.matches(&req(method, uri))); + assert_eq!(Match::No, label_spec.matches(&req(method, uri, None))); } // In the second example, the label is bound to `""`. let hits = vec![(Method::GET, "/a/label"), (Method::GET, "/a/")]; for (method, uri) in &hits { - assert_eq!(Match::Yes, label_spec.matches(&req(method, uri))); + assert_eq!(Match::Yes, label_spec.matches(&req(method, uri, None))); } } } diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 6397794b8fcbf4bab36e74fdd12ed939dd0bd172..9f6ae3f8049142980c930dd643361f1b48a5bd5c 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -25,7 +25,8 @@ use crate::protocols::Protocol; #[derive(Debug)] pub enum RuntimeErrorKind { - // UnknownOperation, + /// The requested operation does not exist. + UnknownOperation, /// Request failed to deserialize or response failed to serialize. Serialization(crate::Error), /// As of writing, this variant can only occur upon failure to extract an @@ -43,6 +44,7 @@ impl RuntimeErrorKind { match self { RuntimeErrorKind::Serialization(_) => "SerializationException", RuntimeErrorKind::InternalFailure(_) => "InternalFailureException", + RuntimeErrorKind::UnknownOperation => "UnknownOperation", } } } @@ -58,11 +60,16 @@ impl axum_core::response::IntoResponse for RuntimeError { let status_code = match self.kind { RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST, RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR, + RuntimeErrorKind::UnknownOperation => http::StatusCode::NOT_FOUND, }; let body = crate::body::to_boxed(match self.protocol { Protocol::RestJson1 => "{}", Protocol::RestXml => "", + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization + Protocol::AwsJson10 => "", + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + Protocol::AwsJson11 => "", }); let mut builder = http::Response::builder(); @@ -74,9 +81,9 @@ impl axum_core::response::IntoResponse for RuntimeError { .header("Content-Type", "application/json") .header("X-Amzn-Errortype", self.kind.name()); } - Protocol::RestXml => { - builder = builder.header("Content-Type", "application/xml"); - } + Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), + Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"), + Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"), } builder = builder.extension(crate::extension::RuntimeErrorExtension::new(String::from(