diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 64d2fd75477e3b5e2d9d729f4bdc494975a46543..73330149cff84f50b2bf121a270b15073fb7a67f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -19,6 +19,7 @@ object ServerCargoDependency { val AxumCore: CargoDependency = CargoDependency("axum-core", CratesIo("0.1")) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3")) val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2")) + val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index 1fa28c7a412798558edae0818d7404fdd13c828b..09aba7492b898092a636e7d47013f7dbc9739dff 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -7,16 +7,21 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.node.ExpectationNotMetException +import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate @@ -101,6 +106,7 @@ private class ServerHttpProtocolImplGenerator( "LazyStatic" to CargoDependency.LazyStatic.asType(), "PercentEncoding" to CargoDependency.PercentEncoding.asType(), "Regex" to CargoDependency.Regex.asType(), + "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(), "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(), "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig), "http" to RuntimeType.http, @@ -532,6 +538,7 @@ private class ServerHttpProtocolImplGenerator( } } serverRenderUriPathParser(this, operationShape) + serverRenderQueryStringParser(this, operationShape) val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) { "?" @@ -583,7 +590,6 @@ private class ServerHttpProtocolImplGenerator( pattern.append(it.content) } } - val errorShape = operationShape.errorSymbol(symbolProvider) with(writer) { rustTemplate( """ @@ -595,7 +601,7 @@ private class ServerHttpProtocolImplGenerator( ) rustBlock("if let Some(captures) = RE.captures(request.uri().path())") { pathBindings.forEach { - val deserializer = generateParseLabelFn(it) + val deserializer = generateParsePercentEncodedStrFn(it) rustTemplate( """ if let Some(m) = captures.name("${it.locationName}") { @@ -605,13 +611,164 @@ private class ServerHttpProtocolImplGenerator( } """.trimIndent(), "deserializer" to deserializer, - "E" to errorShape, ) } } } } + // The `httpQueryParams` trait can be applied to structure members that target: + // * a map of string, + // * a map of list of string; or + // * a map of set of string. + enum class QueryParamsTargetMapValueType { + STRING, LIST, SET; + + fun asRustType(): RustType = + when (this) { + STRING -> RustType.String + LIST -> RustType.Vec(RustType.String) + SET -> RustType.HashSet(RustType.String) + } + } + + private fun queryParamsTargetMapValueType(targetMapValue: Shape): QueryParamsTargetMapValueType = + if (targetMapValue.isStringShape) { + QueryParamsTargetMapValueType.STRING + } else if (targetMapValue.isListShape) { + QueryParamsTargetMapValueType.LIST + } else if (targetMapValue.isSetShape) { + QueryParamsTargetMapValueType.SET + } else { + throw ExpectationNotMetException(""" + @httpQueryParams trait applied to non-supported target + $targetMapValue of type ${targetMapValue.type} + """.trimIndent(), + targetMapValue.sourceLocation) + } + + private fun serverRenderQueryStringParser(writer: RustWriter, operationShape: OperationShape) { + val queryBindings = + httpBindingResolver.requestBindings(operationShape).filter { + it.location == HttpLocation.QUERY + } + // Only a single structure member can be bound to `httpQueryParams`, hence `find`. + val queryParamsBinding = + httpBindingResolver.requestBindings(operationShape).find { + it.location == HttpLocation.QUERY_PARAMS + } + if (queryBindings.isEmpty() && queryParamsBinding == null) { + return + } + + fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType { + check(this.location == HttpLocation.QUERY_PARAMS) + val queryParamsTarget = model.expectShape(this.member.target) + val mapTarget = queryParamsTarget.asMapShape().get() + return queryParamsTargetMapValueType(model.expectShape(mapTarget.value.target)) + } + + with(writer) { + rustTemplate(""" + let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?; + let pairs = #{SerdeUrlEncoded}::from_str::>(query_string)?; + """.trimIndent(), + *codegenScope + ) + + if (queryParamsBinding != null) { + rustTemplate("let mut query_params: #{HashMap} = #{HashMap}::new();", + "HashMap" to RustType.HashMap.RuntimeType, + ) + } + val (queryBindingsTargettingCollection, queryBindingsTargettingSimple) = + queryBindings.partition { model.expectShape(it.member.target) is CollectionShape } + queryBindingsTargettingSimple.forEach { + rust("let mut seen_${it.memberName.toSnakeCase()} = false;") + } + queryBindingsTargettingCollection.forEach { + rust("let mut ${it.memberName.toSnakeCase()} = Vec::new();") + } + + rustBlock("for (k, v) in pairs") { + queryBindingsTargettingSimple.forEach { + val deserializer = generateParsePercentEncodedStrFn(it) + rustTemplate(""" + if !seen_${it.memberName.toSnakeCase()} && k == "${it.locationName}" { + input = input.${it.member.setterName()}( + #{deserializer}(v)? + ); + seen_${it.memberName.toSnakeCase()} = true; + } + """.trimIndent(), + "deserializer" to deserializer) + } + queryBindingsTargettingCollection.forEach { + rustBlock("if k == ${it.locationName.dq()}") { + val targetCollectionShape = model.expectShape(it.member.target, CollectionShape::class.java) + val memberShape = model.expectShape(targetCollectionShape.member.target) + + when { + memberShape.isStringShape -> { + // `<_>::from()` is necessary to convert the `&str` into: + // * the Rust enum in case the `string` shape has the `enum` trait; or + // * `String` in case it doesn't. + rustTemplate(""" + let v = <_>::from(#{PercentEncoding}::percent_decode_str(v).decode_utf8()?.as_ref()); + """.trimIndent(), *codegenScope) + } + memberShape.isTimestampShape -> { + val index = HttpBindingIndex.of(model) + val timestampFormat = + index.determineTimestampFormat( + it.member, + it.location, + protocol.defaultTimestampFormat, + ) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + rustTemplate(""" + let v = #{PercentEncoding}::percent_decode_str(v).decode_utf8()?; + let v = #{DateTime}::from_str(&v, #{format})?; + """.trimIndent(), + *codegenScope, + "format" to timestampFormatType, + ) + } + else -> { // Number or boolean. + rust(""" + let v = <_ as #T>::parse_smithy_primitive(v)?; + """.trimIndent(), CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")) + } + } + rust("${it.memberName.toSnakeCase()}.push(v);") + } + } + + if (queryParamsBinding != null) { + when (queryParamsBinding.queryParamsBindingTargetMapValueType()) { + QueryParamsTargetMapValueType.STRING -> { + rust("query_params.entry(String::from(k)).or_insert_with(|| String::from(v));") + } else -> { + rustTemplate(""" + let entry = query_params.entry(String::from(k)).or_default(); + entry.push(String::from(v)); + """.trimIndent()) + } + } + } + } + if (queryParamsBinding != null) { + rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));") + } + queryBindingsTargettingCollection.forEach { + rustTemplate(""" + input = input.${it.member.setterName()}(Some(${it.memberName.toSnakeCase()})); + """.trimIndent()) + } + } + } + private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { val httpBindingGenerator = ResponseBindingGenerator( @@ -629,39 +786,43 @@ private class ServerHttpProtocolImplGenerator( ) } - private fun generateParseLabelFn(binding: HttpBindingDescriptor): RuntimeType { - check(binding.location == HttpLocation.LABEL) + private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType { + // HTTP bindings we support that contain percent-encoded data. + check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY) + val target = model.expectShape(binding.member.target) return when { - target.isStringShape -> generateParseLabelStringFn(binding) - target.isTimestampShape -> generateParseLabelTimestampFn(binding) - else -> generateParseLabelPrimitiveFn(binding) + target.isStringShape -> generateParsePercentEncodedStrAsStringFn(binding) + target.isTimestampShape -> generateParsePercentEncodedStrAsTimestampFn(binding) + else -> generateParseStrAsPrimitiveFn(binding) } } - private fun generateParseLabelStringFn(binding: HttpBindingDescriptor): RuntimeType { + private fun generateParsePercentEncodedStrAsStringFn(binding: HttpBindingDescriptor): RuntimeType { val output = symbolProvider.toSymbol(binding.member) - val fnName = generateParseLabelFnName(binding) + val fnName = generateParseStrFnName(binding) return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer -> writer.rustBlockTemplate( "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{SmithyRejection}>", *codegenScope, "O" to output, ) { - rustTemplate( - """ - let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?; - Ok(Some(value.into_owned())) - """.trimIndent(), + // `<_>::from()` is necessary to convert the `&str` into: + // * the Rust enum in case the `string` shape has the `enum` trait; or + // * `String` in case it doesn't. + rustTemplate(""" + let value = <_>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref()); + Ok(Some(value)) + """.trimIndent(), *codegenScope, ) } } } - private fun generateParseLabelTimestampFn(binding: HttpBindingDescriptor): RuntimeType { + private fun generateParsePercentEncodedStrAsTimestampFn(binding: HttpBindingDescriptor): RuntimeType { val output = symbolProvider.toSymbol(binding.member) - val fnName = generateParseLabelFnName(binding) + val fnName = generateParseStrFnName(binding) val index = HttpBindingIndex.of(model) val timestampFormat = index.determineTimestampFormat( @@ -689,9 +850,10 @@ private class ServerHttpProtocolImplGenerator( } } - private fun generateParseLabelPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType { + // TODO These functions can be replaced with the ones in https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html + private fun generateParseStrAsPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType { val output = symbolProvider.toSymbol(binding.member) - val fnName = generateParseLabelFnName(binding) + val fnName = generateParseStrFnName(binding) return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer -> writer.rustBlockTemplate( "pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{SmithyRejection}>", @@ -709,9 +871,9 @@ private class ServerHttpProtocolImplGenerator( } } - private fun generateParseLabelFnName(binding: HttpBindingDescriptor): String { + private fun generateParseStrFnName(binding: HttpBindingDescriptor): String { val containerName = binding.member.container.name.toSnakeCase() val memberName = binding.memberName.toSnakeCase() - return "parse_label_${containerName}_$memberName" + return "parse_str_${containerName}_$memberName" } } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 5e7a9a62046bbb31f3b4c7efef91b1772823e5fd..05630a5654da173cacb7971b904eba06f623c255 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -64,6 +64,13 @@ define_rejection! { pub struct MissingJsonContentType; } +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Expected query string in URI but none found"] + /// Rejection type used if the URI has no query string and we need to deserialize data from it. + pub struct MissingQueryString; +} + define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to parse request MIME type"] @@ -125,6 +132,7 @@ composite_rejection! { BodyAlreadyExtracted, HeadersAlreadyExtracted, ExtensionsAlreadyExtracted, + MissingQueryString, } } @@ -164,6 +172,12 @@ impl From for SmithyRejection { } } +impl From for SmithyRejection { + fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self { + SmithyRejection::Deserialize(Deserialize::from_err(err)) + } +} + impl From for SmithyRejection { fn from(err: aws_smithy_http::operation::SerializationError) -> Self { SmithyRejection::Serialize(Serialize::from_err(err)) @@ -193,3 +207,9 @@ impl From for SmithyRejection { SmithyRejection::HeadersParse(HeadersParse::from_err(err)) } } + +impl From for SmithyRejection { + fn from(err: serde_urlencoded::de::Error) -> Self { + SmithyRejection::Deserialize(Deserialize::from_err(err)) + } +}