From c049a37f8cba5f9bec2e96c28db83e7efb2edc53 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Mon, 29 Mar 2021 19:03:26 -0400 Subject: [PATCH] RestJson Response Parsing (#272) * restJson1 Header Deserialization Support This adds support for the `@httpHeader` trait during request deserialization. This is facilitated via a `Read` abstraction which allows parsers to consume part of an input and return the remaining input. We use this to enable things like `HttpDates` which actually include commas to be parsed in a comma delimited fashion. The internal libraries may eventually be refactored to reduce the amount of generics if it proves to be a compiler bottleneck. * Add support for parsing request bodies * Add response code parsing * Lots of refactorings to restJson response parsing * Fix some clippy lints * Add document deserialization support and delete the failing tests --- codegen-test/model/rest-json-extras.smithy | 6 +- .../smithy/rust/codegen/rustlang/RustTypes.kt | 1 + .../rust/codegen/smithy/RuntimeTypes.kt | 2 + .../rust/codegen/smithy/SymbolVisitor.kt | 2 +- .../smithy/generators/BuilderGenerator.kt | 11 + .../smithy/generators/EnumGenerator.kt | 14 ++ .../generators/HttpProtocolGenerator.kt | 1 + .../generators/HttpProtocolTestGenerator.kt | 5 +- .../RequestBindingGenerator.kt} | 12 +- .../http/ResponseBindingGenerator.kt | 208 ++++++++++++++++++ .../codegen/smithy/protocols/AwsJson10.kt | 14 +- .../codegen/smithy/protocols/AwsRestJson.kt | 152 +++++++++++-- .../rust/codegen/testutil/TestHelpers.kt | 26 ++- .../RequestBindingGeneratorTest.kt} | 10 +- .../http/ResponseBindingGeneratorTest.kt | 110 +++++++++ rust-runtime/smithy-http/src/header.rs | 145 ++++++++++++ rust-runtime/smithy-http/src/lib.rs | 1 + .../smithy-types/src/instant/format.rs | 53 +++++ rust-runtime/smithy-types/src/instant/mod.rs | 52 +++++ 19 files changed, 787 insertions(+), 38 deletions(-) rename codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/{HttpTraitBindingGenerator.kt => http/RequestBindingGenerator.kt} (96%) create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt rename codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/{HttpTraitBindingGeneratorTest.kt => http/RequestBindingGeneratorTest.kt} (97%) create mode 100644 codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt create mode 100644 rust-runtime/smithy-http/src/header.rs diff --git a/codegen-test/model/rest-json-extras.smithy b/codegen-test/model/rest-json-extras.smithy index c06463472..36f1eaaa4 100644 --- a/codegen-test/model/rest-json-extras.smithy +++ b/codegen-test/model/rest-json-extras.smithy @@ -27,7 +27,8 @@ service RestJsonExtras { } ]) operation EnumPayload { - input: EnumPayloadInput + input: EnumPayloadInput, + output: EnumPayloadInput } structure EnumPayloadInput { @@ -50,7 +51,8 @@ string StringEnum } ]) operation StringPayload { - input: StringPayloadInput + input: StringPayloadInput, + output: StringPayloadInput } structure StringPayloadInput { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt index ae102f28e..56775d099 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt @@ -202,6 +202,7 @@ sealed class Attribute { */ val NonExhaustive = Custom("non_exhaustive") val AllowUnused = Custom("allow(dead_code)") + val AllowUnusedMut = Custom("allow(unused_mut)") } data class Derives(val derives: Set) : Attribute() { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index 863412d58..73527e08d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.rustlang.InlineDependency import software.amazon.smithy.rust.codegen.rustlang.RustDependency import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType import java.util.Optional data class RuntimeConfig(val cratePrefix: String = "smithy", val relativePath: String = "../") { @@ -125,6 +126,7 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n func, CargoDependency.ProtocolTestHelpers(runtimeConfig), "protocol_test_helpers" ) + val http = CargoDependency.Http.asType() fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt index 723be2fad..11901d80a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt @@ -143,7 +143,7 @@ class SymbolVisitor( // an Input shape, then the field is _not optional_. val httpLabeledInput = container.hasTrait(SyntheticInputTrait::class.java) && member.hasTrait(HttpLabelTrait::class.java) - return if (nullableIndex.isNullable(member) && !httpLabeledInput) { + return if (nullableIndex.isNullable(member) && !httpLabeledInput || model.expectShape(member.target).isDocumentShape) { symbol.makeOptional() } else symbol } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt index 63b7855f1..c1622b0e2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.defaultValue import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.makeOptional import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq @@ -128,6 +129,9 @@ class OperationInputBuilderGenerator( } } +/** setter names will never hit a reserved word and therefore never need escaping */ +fun MemberShape.setterName(): String = "set_${this.memberName.toSnakeCase()}" + abstract class BuilderGenerator( val model: Model, private val symbolProvider: RustSymbolProvider, @@ -194,6 +198,13 @@ abstract class BuilderGenerator( write("self.$memberName = Some(${builderConverter(coreType)});") write("self") } + + writer.rustBlock("pub fn ${member.setterName()}(mut self, inp: ${outerType.render(true)}) -> Self") { + val v = "inp".letIf(outerType !is RustType.Option) { + "Some($it)" + } + rust("self.$memberName = $v; self") + } } buildFn(this) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt index f5a5fbea8..3e6f11874 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/EnumGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock @@ -141,5 +142,18 @@ class EnumGenerator( } } } + + writer.rust( + """ + impl std::str::FromStr for $enumName { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> Result { + Ok($enumName::from(s)) + } + } + + """ + ) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt index 5209be99a..8b519393c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt @@ -147,6 +147,7 @@ abstract class HttpProtocolGenerator( operationShape: OperationShape, block: RustWriter.() -> Unit ) { + Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter) implBlockWriter.rustBlock( "fn from_response(response: &#T>) -> Result<#T, #T>", RuntimeType.Http("response::Response"), diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt index 55492f2a8..eab8af1bf 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -363,7 +363,10 @@ class HttpProtocolTestGenerator( FailingTest(AwsJson11, "AwsJson11EndpointTrait", Action.Request), FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", Action.Request), FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request), - FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request) + FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request), + + FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response), + FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Response), ) private val RunOnly: Set? = null diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt similarity index 96% rename from codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt rename to codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt index bed0d5a92..2f256b000 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt @@ -3,9 +3,8 @@ * SPDX-License-Identifier: Apache-2.0. */ -package software.amazon.smithy.rust.codegen.smithy.generators +package software.amazon.smithy.rust.codegen.smithy.generators.http -import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -26,6 +25,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary +import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember @@ -54,9 +57,9 @@ fun HttpTrait.uriFormatString(): String { * TODO: httpPrefixHeaders; 4h * TODO: Deserialization of all fields; 1w */ -class HttpTraitBindingGenerator( +class RequestBindingGenerator( val model: Model, - private val symbolProvider: SymbolProvider, + private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig, private val writer: RustWriter, private val shape: OperationShape, @@ -67,6 +70,7 @@ class HttpTraitBindingGenerator( private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS private val index = HttpBindingIndex.of(model) private val buildError = runtimeConfig.operationBuildError() + private val instant = RuntimeType.Instant(runtimeConfig).toSymbol().rustType() /** * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt new file mode 100644 index 000000000..9158d2d09 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt @@ -0,0 +1,208 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators.http + +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.MediaTypeTrait +import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.render +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.rustType +import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.toSnakeCase + +class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val operationShape: OperationShape) { + private val runtimeConfig = protocolConfig.runtimeConfig + private val symbolProvider = protocolConfig.symbolProvider + private val model = protocolConfig.model + private val index = HttpBindingIndex.of(model) + private val headerUtil = CargoDependency.SmithyHttp(runtimeConfig).asType().member("header") + private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS + private val instant = RuntimeType.Instant(runtimeConfig).toSymbol().rustType() + + /** + * Generate a function to deserialize [binding] from HTTP headers + * + * The name of the resulting function is returned as a String + * + * Generates a function like: + * ```rust + * fn parse_foo(headers: &http::HeaderMap) -> Result, ParseError> { + * ... + * } + * ``` + */ + fun generateDeserializeHeaderFn(binding: HttpBinding, writer: RustWriter): String { + val outputT = symbolProvider.toSymbol(binding.member) + val fnName = "parse_from_header_${binding.memberName.toSnakeCase()}" + writer.rustBlock( + "fn $fnName(headers: &#T::HeaderMap) -> Result<#T, #T::ParseError>", + RuntimeType.http, + outputT, + headerUtil + ) { + deserializeFromHeader(binding) + } + return fnName + } + + /** + * Generate a function to deserialize `[binding]` from the response payload + */ + fun generateDeserializePayloadFn( + binding: HttpBinding, + errorT: RuntimeType, + rustWriter: RustWriter, + // Deserialize a single structure or union member marked as a payload + structuredHandler: RustWriter.(String) -> Unit, + // Deserialize a document type marked as a payload + docHandler: RustWriter.(String) -> Unit + ): String { + check(binding.location == HttpBinding.Location.PAYLOAD) + val outputT = symbolProvider.toSymbol(binding.member) + val fnName = "parse_from_payload_${binding.memberName.toSnakeCase()}" + rustWriter.rustBlock("fn $fnName(body: &[u8]) -> Result<#T, #T>", outputT, errorT) { + deserializePayloadBody(binding, errorT, structuredHandler = structuredHandler, docShapeHandler = docHandler) + } + return fnName + } + + private fun RustWriter.deserializePayloadBody( + binding: HttpBinding, + errorSymbol: RuntimeType, + structuredHandler: RustWriter.(String) -> Unit, + docShapeHandler: RustWriter.(String) -> Unit + ) { + val member = binding.member + val targetShape = model.expectShape(member.target) + // There is an unfortunate bit of dual behavior caused by an empty body causing the output to be `None` instead + // of an empty instance of the response type. + withBlock("(!body.is_empty()).then(||{", "}).transpose()") { + when (targetShape) { + is StructureShape, is UnionShape -> this.structuredHandler("body") + is StringShape -> { + rustTemplate( + "let body_str = std::str::from_utf8(&body).map_err(#{error_symbol}::unhandled)?;", + "error_symbol" to errorSymbol + ) + if (targetShape.hasTrait(EnumTrait::class.java)) { + rust( + "Ok(#T::from(body_str))", + symbolProvider.toSymbol(targetShape) + ) + } else { + rust("Ok(body_str.to_string())") + } + } + is BlobShape -> rust( + "Ok(#T::new(body))", + RuntimeType.Blob(runtimeConfig) + ) + is DocumentShape -> this.docShapeHandler("body") + else -> TODO("unexpected shape: $targetShape") + } + } + } + + /** Parse a value from a header + * This function produces an expression which produces the precise output type required by the output shape + */ + private fun RustWriter.deserializeFromHeader(binding: HttpBinding) { + check(binding.location == HttpBinding.Location.HEADER) + val targetType = model.expectShape(binding.member.target) + val rustType = symbolProvider.toSymbol(targetType).rustType().stripOuter() + val (coreType, coreShape) = if (targetType is CollectionShape) { + rustType.stripOuter() to model.expectShape(targetType.member.target) + } else { + rustType to targetType + } + val parsedValue = safeName() + if (coreType == instant) { + val timestampFormat = + index.determineTimestampFormat( + binding.member, + HttpBinding.Location.HEADER, + defaultTimestampFormat + ) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + rust( + "let $parsedValue: Vec<${coreType.render(true)}> = #T::many_dates(&headers, ${binding.locationName.dq()}, #T)?;", + headerUtil, + timestampFormatType + ) + } else { + rust( + "let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(&headers, ${binding.locationName.dq()})?;", + headerUtil + ) + if (coreShape.hasTrait(MediaTypeTrait::class.java)) { + rustTemplate( + """let $parsedValue: Result, _> = $parsedValue + .iter().map(|s| + #{base_64_decode}(s).map_err(|_|#{header}::ParseError) + .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError)) + ).collect();""", + "base_64_decode" to RuntimeType.Base64Decode(runtimeConfig), + "header" to headerUtil + ) + rust("let $parsedValue = $parsedValue?;") + } + } + when (rustType) { + is RustType.Vec -> + rust( + """ + Ok(if !$parsedValue.is_empty() { + Some($parsedValue) + } else { + None + }) + """ + ) + is RustType.HashSet -> + rust( + """ + Ok(if !$parsedValue.is_empty() { + Some($parsedValue.into_iter().collect()) + } else { + None + }) + """ + ) + else -> + rustTemplate( + """ + if $parsedValue.len() > 1 { + Err(#{header_util}::ParseError) + } else { + let mut $parsedValue = $parsedValue; + Ok($parsedValue.pop()) + } + """, + "header_util" to headerUtil + ) + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt index 9c0ed321d..b8458adb8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt @@ -24,12 +24,14 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.Serializers import software.amazon.smithy.rust.codegen.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.generators.errorSymbol import software.amazon.smithy.rust.codegen.smithy.locatedIn +import software.amazon.smithy.rust.codegen.smithy.meta import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait @@ -110,9 +112,15 @@ class SyntheticBodySymbolProvider(private val model: Model, private val base: Ru override fun toSymbol(shape: Shape): Symbol { val initialSymbol = base.toSymbol(shape) val override = when (shape) { - is StructureShape -> if (shape.hasTrait(InputBodyTrait::class.java) || shape.hasTrait(OutputBodyTrait::class.java)) { - initialSymbol.toBuilder().locatedIn(Serializers).build() - } else null + is StructureShape -> when { + shape.hasTrait(InputBodyTrait::class.java) -> + initialSymbol.toBuilder().locatedIn(Serializers).build() + shape.hasTrait(OutputBodyTrait::class.java) -> + initialSymbol.toBuilder().locatedIn(Serializers).meta( + initialSymbol.expectRustMetadata().withDerives(RuntimeType("Default", null, "std::default")) + ).build() + else -> null + } is MemberShape -> { val container = model.expectShape(shape.container) if (container.hasTrait(InputBodyTrait::class.java)) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt index 701fbb740..36e24eca2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.smithy.protocols +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -18,7 +19,10 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.rustlang.Attribute +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate @@ -26,14 +30,22 @@ import software.amazon.smithy.rust.codegen.rustlang.writable import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator -import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport +import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.setterName import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember +import software.amazon.smithy.rust.codegen.util.outputShape +import java.util.logging.Logger class AwsRestJsonFactory : ProtocolGeneratorFactory { override fun buildProtocolGenerator( @@ -41,30 +53,28 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory { ): AwsRestJsonGenerator = AwsRestJsonGenerator(protocolConfig) /** Create a synthetic awsJsonInputBody if specified - * A body is created iff no member of [input] is targeted with the `PAYLOAD` trait. If a member is targeted with - * the payload trait, we don't need to create an input body. + * A body is created if any member of [shape] is bound to the `DOCUMENT` section of the `bindings. */ - private fun awsJsonInputBody(model: Model, operation: OperationShape, input: StructureShape?): StructureShape? { - if (input == null) { + private fun restJsonBody(shape: StructureShape?, bindings: Map): StructureShape? { + if (shape == null) { return null } - val bindingIndex = HttpBindingIndex.of(model) - val bindings: MutableMap = bindingIndex.getRequestBindings(operation) - val bodyMembers = input.members().filter { member -> + val bodyMembers = shape.members().filter { member -> bindings[member.memberName]?.location == HttpBinding.Location.DOCUMENT } return if (bodyMembers.isNotEmpty()) { - input.toBuilder().members(bodyMembers).build() + shape.toBuilder().members(bodyMembers).build() } else { null } } override fun transformModel(model: Model): Model { + val httpIndex = HttpBindingIndex.of(model) return OperationNormalizer(model).transformModel( - inputBodyFactory = { op, input -> awsJsonInputBody(model, op, input) }, - outputBodyFactory = OperationNormalizer.NoBody + inputBodyFactory = { op, input -> restJsonBody(input, httpIndex.getRequestBindings(op)) }, + outputBodyFactory = { op, output -> restJsonBody(output, httpIndex.getResponseBindings(op)) }, ) } @@ -72,7 +82,7 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory { // TODO: Support body for RestJson return ProtocolSupport( requestBodySerialization = true, - responseDeserialization = false, + responseDeserialization = true, errorDeserialization = false ) } @@ -90,6 +100,7 @@ class AwsRestJsonGenerator( private val protocolConfig: ProtocolConfig ) : HttpProtocolGenerator(protocolConfig) { // restJson1 requires all operations to use the HTTP trait + private val logger = Logger.getLogger(javaClass.name) private val model = protocolConfig.model override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) { @@ -97,17 +108,120 @@ class AwsRestJsonGenerator( } override fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) { + val outputShape = operationShape.outputShape(model) + val bodyId = outputShape.expectTrait(SyntheticOutputTrait::class.java).body + val bodyShape = bodyId?.let { model.expectShape(bodyId, StructureShape::class.java) } + val errorSymbol = operationShape.errorSymbol(symbolProvider) + val httpBindingGenerator = ResponseBindingGenerator(protocolConfig, operationShape) + + // Parsing the response works in two phases: + // 1. Code generate "parse_xyz" methods for each field + // 2. Code generate a parse_response method which utilizes the parse_xyz methods to set fields on a builder + val parseFunctions: Map = + renderParseFunctions(operationShape, httpBindingGenerator, implBlockWriter, bodyShape) + fromResponseFun(implBlockWriter, operationShape) { // avoid non-usage warnings - rust( - """ - let _ = response; - todo!() - """ - ) + Attribute.AllowUnusedMut.render(this) + rust("let mut output = #T::default();", outputShape.builderSymbol(symbolProvider)) + rust("let _ = response;") + if (bodyShape != null) { + rustTemplate( + """ + let body_slice = response.body().as_ref(); + + let parsed_body: #{body} = if body_slice.is_empty() { + #{body}::default() + } else { + #{from_slice}(response.body().as_ref()).map_err(#{err_symbol}::unhandled)? + }; + """, + "body" to symbolProvider.toSymbol(bodyShape), + "from_slice" to RuntimeType.SerdeJson("from_slice"), + "err_symbol" to errorSymbol + ) + } + outputShape.members().forEach { member -> + val parsedValue = parseFunctions[member.memberName] + ?: throw CodegenException("No parser defined for $member!. This is a bug") + // can delete when we don't have `todo!()` here anymore + Attribute.Custom("allow(unreachable_code, clippy::diverging_sub_expression)").render(this) + rust("{ output = output.${member.setterName()}($parsedValue); }") + } + + val err = if (StructureGenerator.fallibleBuilder(outputShape, symbolProvider)) { + ".map_err(|s|${format(errorSymbol)}::unhandled(s))?" + } else "" + rust("Ok(output.build()$err)") } } + /** + * Generate a parser & a parsed value converter for each output member of `operationShape` + * + * Returns a map with key = memberName, value = parsedValue + */ + private fun renderParseFunctions( + operationShape: OperationShape, + httpBindingGenerator: ResponseBindingGenerator, + implBlockWriter: RustWriter, + bodyShape: StructureShape? + ): Map { + val bindings = httpIndex.getResponseBindings(operationShape) + val outputShape = operationShape.outputShape(model) + val errorSymbol = operationShape.errorSymbol(symbolProvider) + return outputShape.members().map { member -> + val binding = bindings[member.memberName] ?: throw CodegenException("Binding should be defined") + member.memberName to when (binding.location) { + HttpBinding.Location.HEADER -> { + val fnName = httpBindingGenerator.generateDeserializeHeaderFn(binding, implBlockWriter) + """ + Self::$fnName( + response.headers() + ).map_err(|_|${implBlockWriter.format(errorSymbol)}::unhandled("Failed to parse ${member.memberName} from header `${binding.locationName}"))? + """ + } + HttpBinding.Location.DOCUMENT -> { + check(bodyShape != null) { + "$bodyShape was null but a member specified document bindings. This is a bug." + } + // When there is a subset of fields present as the body of the response, we will create a variable + // named `parsed_body`. Copy the field from parsed_body into the builder + + "parsed_body.${symbolProvider.toMemberName(member)}" + } + HttpBinding.Location.PAYLOAD -> { + val docShapeHandler: RustWriter.(String) -> Unit = { body -> + rustTemplate( + """ + #{serde_json}::from_slice::<#{doc_json}::DeserDoc>($body).map(|d|d.0).map_err(#{error_symbol}::unhandled) + """, + "doc_json" to RuntimeType.DocJson, + "serde_json" to CargoDependency.SerdeJson.asType(), + "error_symbol" to errorSymbol + ) + } + val structureShapeHandler: RustWriter.(String) -> Unit = { body -> + rust("#T($body).map_err(#T::unhandled)", RuntimeType.SerdeJson("from_slice"), errorSymbol) + } + val fnName = httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + implBlockWriter, + docHandler = docShapeHandler, + structuredHandler = structureShapeHandler + ) + "Self::$fnName(response.body().as_ref())?" + } + HttpBinding.Location.RESPONSE_CODE -> "Some(response.status().as_u16() as _)" + else -> { + logger.warning("Unhandled response binding type: ${binding.location}") + "todo!()" + } + } + }.toMap() + } + private fun serializeViaSyntheticBody( implBlockWriter: RustWriter, inputBody: StructureShape @@ -211,7 +325,7 @@ class AwsRestJsonGenerator( ) { val httpTrait = operationShape.expectTrait(HttpTrait::class.java) - val httpBindingGenerator = HttpTraitBindingGenerator( + val httpBindingGenerator = RequestBindingGenerator( model, symbolProvider, runtimeConfig, diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt index 4ac6c99b8..5d4e520d8 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/testutil/TestHelpers.kt @@ -6,6 +6,8 @@ package software.amazon.smithy.rust.codegen.testutil import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.smithy.CodegenConfig @@ -14,6 +16,7 @@ import software.amazon.smithy.rust.codegen.smithy.RustCodegenPlugin import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.smithy.generators.ModelBuilderGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.smithy.letIf @@ -21,13 +24,30 @@ import software.amazon.smithy.rust.codegen.util.dq import java.io.File val TestRuntimeConfig = RuntimeConfig(relativePath = File("../rust-runtime/").absolutePath) -val TestSymbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = TestRuntimeConfig, codegenConfig = CodegenConfig(), handleOptionality = true, handleRustBoxing = true) -fun testSymbolProvider(model: Model): RustSymbolProvider = RustCodegenPlugin.BaseSymbolProvider(model, TestSymbolVisitorConfig) +val TestSymbolVisitorConfig = SymbolVisitorConfig( + runtimeConfig = TestRuntimeConfig, + codegenConfig = CodegenConfig(), + handleOptionality = true, + handleRustBoxing = true +) + +fun testSymbolProvider(model: Model): RustSymbolProvider = + RustCodegenPlugin.BaseSymbolProvider(model, TestSymbolVisitorConfig) + +fun testProtocolConfig(model: Model): ProtocolConfig = ProtocolConfig( + model, + testSymbolProvider(model), + TestRuntimeConfig, + ServiceShape.builder().version("test").id("test#Service").build(), + ShapeId.from("test#Protocol"), + "test" +) private const val SmithyVersion = "1.0" fun String.asSmithyModel(sourceLocation: String? = null): Model { val processed = letIf(!this.startsWith("\$version")) { "\$version: ${SmithyVersion.dq()}\n$it" } - return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble().unwrap() + return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble() + .unwrap() } /** diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt similarity index 97% rename from codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt rename to codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt index 350649c59..c83eba1c6 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/RequestBindingGeneratorTest.kt @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. */ -package software.amazon.smithy.rust.codegen.generators +package software.amazon.smithy.rust.codegen.generators.http import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test @@ -14,9 +14,9 @@ import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.http.uriFormatString import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError -import software.amazon.smithy.rust.codegen.smithy.generators.uriFormatString import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.testutil.asSmithyModel @@ -25,7 +25,7 @@ import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.util.dq -class HttpTraitBindingGeneratorTest { +class RequestBindingGeneratorTest { private val baseModel = """ namespace smithy.example @@ -113,7 +113,7 @@ class HttpTraitBindingGeneratorTest { inputShape.renderWithModelBuilder(model, symbolProvider, writer) val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java) writer.rustBlock("impl PutObjectInput") { - HttpTraitBindingGenerator( + RequestBindingGenerator( model, symbolProvider, TestRuntimeConfig, writer, operationShape, inputShape, httpTrait diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt new file mode 100644 index 000000000..f557c6227 --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/http/ResponseBindingGeneratorTest.kt @@ -0,0 +1,110 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.generators.http + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder +import software.amazon.smithy.rust.codegen.testutil.testProtocolConfig +import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.testutil.unitTest +import software.amazon.smithy.rust.codegen.util.outputShape + +class ResponseBindingGeneratorTest { + private val baseModel = """ + namespace smithy.example + + @idempotent + @http(method: "PUT", uri: "/", code: 200) + operation PutObject { + output: PutObjectResponse + } + + list Extras { + member: Integer + } + + list Dates { + member: Timestamp + } + + @mediaType("video/quicktime") + string Video + + structure PutObjectResponse { + // Sent in the X-Dates header + @httpHeader("X-Dates") + dateHeaderList: Dates, + + @httpHeader("X-Ints") + intList: Extras, + + @httpHeader("X-MediaType") + mediaType: Video, + + // Sent in the body + data: Blob, + + // Sent in the body + additional: String, + } + """.asSmithyModel() + private val model = OperationNormalizer(baseModel).transformModel( + inputBodyFactory = OperationNormalizer.NoBody, + outputBodyFactory = OperationNormalizer.NoBody + ) + private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java) + private val symbolProvider = testSymbolProvider(model) + private val testProtocolConfig: ProtocolConfig = testProtocolConfig(model) + + private fun RustWriter.renderOperation() { + operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) + rustBlock("impl PutObjectOutput") { + val bindings = HttpBindingIndex.of(model).getResponseBindings(operationShape, HttpBinding.Location.HEADER) + bindings.forEach { binding -> + + ResponseBindingGenerator( + testProtocolConfig, operationShape + ).generateDeserializeHeaderFn(binding, this) + } + } + } + + @Test + fun deserializeHeadersIntoOutputShape() { + val testProject = TestWorkspace.testProject(symbolProvider) + testProject.withModule(RustModule.default("output", public = true)) { + it.renderOperation() + it.unitTest( + """ + let resp = http::Response::builder() + .header("X-Ints", "1,2,3") + .header("X-Ints", "4,5,6") + .header("X-MediaType", "c21pdGh5LXJz") + .header("X-Dates", "Mon, 16 Dec 2019 23:48:18 GMT") + .header("X-Dates", "Mon, 16 Dec 2019 23:48:18 GMT,Tue, 17 Dec 2019 23:48:18 GMT") + .body(()).expect("valid request"); + assert_eq!(PutObjectOutput::parse_from_header_int_list(&resp.headers()).unwrap(), Some(vec![1,2,3,4,5,6])); + assert_eq!(PutObjectOutput::parse_from_header_media_type(&resp.headers()).expect("valid").unwrap(), "smithy-rs"); + assert_eq!(PutObjectOutput::parse_from_header_date_header_list(&resp.headers()).unwrap().unwrap().len(), 3); + """ + ) + } + testProject.compileAndTest() + } +} diff --git a/rust-runtime/smithy-http/src/header.rs b/rust-runtime/smithy-http/src/header.rs new file mode 100644 index 000000000..a3ced255e --- /dev/null +++ b/rust-runtime/smithy-http/src/header.rs @@ -0,0 +1,145 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +//! Utilities for parsing information from headers + +use smithy_types::instant::Format; +use smithy_types::Instant; +use std::str::FromStr; +use std::fmt::{Display, Formatter}; +use std::fmt; +use std::error::Error; + +#[derive(Debug)] +pub struct ParseError; + +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Output failed to parse in headers") + } +} + +impl Error for ParseError {} + +/// Read all the dates from the header map at `key` according the `format` +/// +/// This is separate from `read_many` below because we need to invoke `Instant::read` to take advantage +/// of comma-aware parsing +pub fn many_dates( + headers: &http::HeaderMap, + key: &str, + format: Format, +) -> Result, ParseError> { + let mut out = vec![]; + for header in headers.get_all(key).iter() { + let mut header = header.to_str().map_err(|_| ParseError)?; + while !header.is_empty() { + let (v, next) = Instant::read(header, format, ',').map_err(|_| ParseError)?; + out.push(v); + header = next; + } + } + Ok(out) +} + +/// Read many comma / header delimited values from HTTP headers for `FromStr` types +pub fn read_many(headers: &http::HeaderMap, key: &str) -> Result, ParseError> +where + T: FromStr, +{ + let mut out = vec![]; + for header in headers.get_all(key).iter() { + let mut header = header.as_bytes(); + while !header.is_empty() { + let (v, next) = read_one::(&header)?; + out.push(v); + header = next; + } + } + Ok(out) +} + +/// Read one comma delimited value for `FromStr` types +pub fn read_one(s: &[u8]) -> Result<(T, &[u8]), ParseError> +where + T: FromStr, +{ + let (head, rest) = split_at_delim(s); + let head = std::str::from_utf8(head).map_err(|_| ParseError)?; + Ok((T::from_str(head.trim()).map_err(|_| ParseError)?, rest)) +} + +fn split_at_delim(s: &[u8]) -> (&[u8], &[u8]) { + let next_delim = s.iter().position(|b| b == &b',').unwrap_or(s.len()); + let (first, next) = s.split_at(next_delim); + (first, then_delim(next).unwrap()) +} + +fn then_delim(s: &[u8]) -> Result<&[u8], ParseError> { + if s.is_empty() { + Ok(&s) + } else if s.starts_with(b",") { + Ok(&s[1..]) + } else { + Err(ParseError) + } +} + +#[cfg(test)] +mod test { + use crate::header::read_many; + + #[test] + fn read_many_bools() { + let test_request = http::Request::builder() + .header("X-Bool-Multi", "true,false") + .header("X-Bool-Multi", "true") + .header("X-Bool", "true") + .header("X-Bool-Invalid", "truth,falsy") + .header("X-Bool-Single", "true,false,true,true") + .body(()) + .unwrap(); + assert_eq!( + read_many::(test_request.headers(), "X-Bool-Multi").expect("valid"), + vec![true, false, true] + ); + + assert_eq!( + read_many::(test_request.headers(), "X-Bool").unwrap(), + vec![true] + ); + assert_eq!( + read_many::(test_request.headers(), "X-Bool-Single").unwrap(), + vec![true, false, true, true] + ); + read_many::(test_request.headers(), "X-Bool-Invalid").expect_err("invalid"); + } + + #[test] + fn read_many_u16() { + let test_request = http::Request::builder() + .header("X-Multi", "123,456") + .header("X-Multi", "789") + .header("X-Num", "777") + .header("X-Num-Invalid", "12ef3") + .header("X-Num-Single", "1,2,3,4,5") + .body(()) + .unwrap(); + assert_eq!( + read_many::(test_request.headers(), "X-Multi").expect("valid"), + vec![123, 456, 789] + ); + + assert_eq!( + read_many::(test_request.headers(), "X-Num").unwrap(), + vec![777] + ); + assert_eq!( + read_many::(test_request.headers(), "X-Num-Single").unwrap(), + vec![1, 2, 3, 4, 5] + ); + read_many::(test_request.headers(), "X-Num-Invalid").expect_err("invalid"); + } +} diff --git a/rust-runtime/smithy-http/src/lib.rs b/rust-runtime/smithy-http/src/lib.rs index 1465a36dd..67e592e27 100644 --- a/rust-runtime/smithy-http/src/lib.rs +++ b/rust-runtime/smithy-http/src/lib.rs @@ -6,6 +6,7 @@ pub mod base64; pub mod body; pub mod endpoint; +pub mod header; pub mod label; pub mod middleware; pub mod operation; diff --git a/rust-runtime/smithy-types/src/instant/format.rs b/rust-runtime/smithy-types/src/instant/format.rs index 30e3862c6..9fd946108 100644 --- a/rust-runtime/smithy-types/src/instant/format.rs +++ b/rust-runtime/smithy-types/src/instant/format.rs @@ -144,6 +144,26 @@ pub mod http_date { parse_imf_fixdate(x) } + pub fn read(s: &str) -> Result<(Instant, &str), DateParseError> { + if !s.is_ascii() { + return Err(DateParseError::Invalid("Date must be valid ascii")); + } + let (first_date, rest) = match find_subsequence(s.as_bytes(), b" GMT") { + // split_at is correct because we asserted that this date is only valid ASCII so the byte index is + // the same as the char index + Some(idx) => s.split_at(idx), + None => return Err(DateParseError::Invalid("Date did not end in GMT")), + }; + Ok((parse(first_date)?, rest)) + } + + fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option { + haystack + .windows(needle.len()) + .position(|window| window == needle) + .map(|idx| idx + needle.len()) + } + fn parse_imf_fixdate(s: &[u8]) -> Result { // Example: `Sun, 06 Nov 1994 08:49:37 GMT` if s.len() < 29 @@ -289,14 +309,27 @@ mod test { ); } + #[test] + fn read_date() { + let fractional = "Mon, 16 Dec 2019 23:48:18.123 GMT,some more stuff"; + let ts = 1576540098; + let expected = Instant::from_fractional_seconds(ts, 0.123); + let (actual, rest) = http_date::read(fractional).expect("valid"); + assert_eq!(rest, ",some more stuff"); + assert_eq!(expected, actual); + http_date::read(rest).expect_err("invalid date"); + } + #[track_caller] fn check_roundtrip(epoch_secs: i64, subsecond_nanos: u32) { let instant = Instant::from_secs_and_nanos(epoch_secs, subsecond_nanos); let formatted = http_date::format(&instant); let parsed = http_date::parse(&formatted); + let read = http_date::read(&formatted); match parsed { Err(failure) => panic!("Date failed to parse {:?}", failure), Ok(date) => { + assert!(read.is_ok()); if date.subsecond_nanos != subsecond_nanos { assert_eq!(http_date::format(&instant), formatted); } else { @@ -330,6 +363,19 @@ mod test { assert_eq!(iso_8601::parse(date), Ok(expected)); } + #[test] + fn read_iso_date_comma_split() { + let date = "1985-04-12T23:20:50Z,1985-04-12T23:20:51Z"; + let (e1, date) = iso_8601::read(date).expect("should succeed"); + let (e2, date2) = iso_8601::read(&date[1..]).expect("should succeed"); + assert_eq!(date2, ""); + assert_eq!(date, ",1985-04-12T23:20:51Z"); + let expected = Instant::from_secs_and_nanos(482196050, 0); + assert_eq!(e1, expected); + let expected = Instant::from_secs_and_nanos(482196051, 0); + assert_eq!(e2, expected); + } + proptest! { #![proptest_config(ProptestConfig::with_cases(10000))] @@ -365,4 +411,11 @@ pub mod iso_8601 { utc_date.timestamp_subsec_nanos(), )) } + + /// Read 1 ISO8601 date from &str and return the remaining str + pub fn read(s: &str) -> Result<(Instant, &str), DateParseError> { + let delim = s.find('Z').map(|idx| idx + 1).unwrap_or_else(|| s.len()); + let (head, rest) = s.split_at(delim); + Ok((parse(dbg!(head))?, &rest)) + } } diff --git a/rust-runtime/smithy-types/src/instant/mod.rs b/rust-runtime/smithy-types/src/instant/mod.rs index 4dbb17e53..079dcd0da 100644 --- a/rust-runtime/smithy-types/src/instant/mod.rs +++ b/rust-runtime/smithy-types/src/instant/mod.rs @@ -66,6 +66,28 @@ impl Instant { } } + /// Read 1 date of `format` from `s`, expecting either `delim` or EOF + /// + /// Enable parsing multiple dates from the same string + pub fn read(s: &str, format: Format, delim: char) -> Result<(Self, &str), DateParseError> { + let (inst, next) = match format { + Format::DateTime => format::iso_8601::read(s)?, + Format::HttpDate => format::http_date::read(s)?, + Format::EpochSeconds => { + let split_point = s.find(delim).unwrap_or_else(|| s.len()); + let (s, rest) = s.split_at(split_point); + (Self::from_str(s, format)?, rest) + } + }; + if next.is_empty() { + Ok((inst, next)) + } else if next.starts_with(delim) { + Ok((inst, &next[1..])) + } else { + Err(DateParseError::Invalid("didn't find expected delimiter")) + } + } + fn to_chrono(&self) -> DateTime { DateTime::::from_utc( NaiveDateTime::from_timestamp(self.seconds, self.subsecond_nanos), @@ -113,11 +135,13 @@ impl Instant { } } +#[derive(Clone, Copy, Eq, PartialEq)] pub enum Format { DateTime, HttpDate, EpochSeconds, } + #[cfg(test)] mod test { use crate::instant::Format; @@ -141,4 +165,32 @@ mod test { "Mon, 16 Dec 2019 23:48:18.520 GMT" ); } + + #[test] + fn test_read_single_http_date() { + let s = "Mon, 16 Dec 2019 23:48:18 GMT"; + let (_, next) = Instant::read(s, Format::HttpDate, ',').expect("valid"); + assert_eq!(next, ""); + } + + #[test] + fn test_read_single_float() { + let s = "1576540098.52"; + let (_, next) = Instant::read(s, Format::EpochSeconds, ',').expect("valid"); + assert_eq!(next, ""); + } + + #[test] + fn test_read_many_float() { + let s = "1576540098.52,1576540098.53"; + let (_, next) = Instant::read(s, Format::EpochSeconds, ',').expect("valid"); + assert_eq!(next, "1576540098.53"); + } + + #[test] + fn test_ready_many_http_date() { + let s = "Mon, 16 Dec 2019 23:48:18 GMT,Tue, 17 Dec 2019 23:48:18 GMT"; + let (_, next) = Instant::read(s, Format::HttpDate, ',').expect("valid"); + assert_eq!(next, "Tue, 17 Dec 2019 23:48:18 GMT"); + } } -- GitLab