diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt index 5a1100a364ffc1687d2b40a058b019f62bbba8a3..07d9b9e271bd2e56e0dcbf18dae2791da988554f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/CargoDependency.kt @@ -90,6 +90,7 @@ class InlineDependency( // Stub config implementation as a placeholder before one can be generated dynamically fun config() = forRustFile("config", "config", "config.rs", CargoDependency.Rand) + fun blobSerde(runtimeConfig: RuntimeConfig) = forRustFile("blob_serde", "blob_serde", "blob_serde.rs", CargoDependency.Serde, CargoDependency.SmithyHttp(runtimeConfig)) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index e25943f4705f6b84c86cb4e992beb21fd5709c7e..b637ef11eb939b88c1198dda8c5b577f950d5ab5 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -137,6 +137,8 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n val InstantEpoch = RuntimeType("instant_epoch", InlineDependency.instantEpoch(), "crate") val InstantHttpDate = RuntimeType("instant_httpdate", InlineDependency.instantHttpDate(), "crate") val Instant8601 = RuntimeType("instant_8601", InlineDependency.instant8601(), "crate") + + fun BlobSerde(runtimeConfig: RuntimeConfig) = RuntimeType("blob_serde", InlineDependency.blobSerde(runtimeConfig), "crate") val Config = RuntimeType("config", InlineDependency.config(), "crate") fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt index 6edfdd86f1a02be75863b87aae8cb88effcb1fc8..c5f9df13f9716f39441e25bc240c4044c8d6b333 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/JsonSerializerSymbolProvider.kt @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.rustTemplate import software.amazon.smithy.rust.codegen.lang.stripOuter +import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolMetadataProvider @@ -129,6 +130,7 @@ class SerializerBuilder( private val handWrittenSerializers: Map Unit> = mapOf( "stdoptionoptionblob_ser" to { writer -> + writer.rustBlock("match $inp") { write( "Some(blob) => $ser.serialize_str(&#T(blob.as_ref())),", @@ -143,7 +145,6 @@ class SerializerBuilder( RuntimeType.Base64Encode(runtimeConfig) ) }, - "stdoptionoptioninstant_http_date_ser" to { writer -> val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, TimestampFormatTrait.Format.HTTP_DATE) writer.rustBlock("match $inp") { @@ -181,79 +182,67 @@ class SerializerBuilder( } ) - // TODO: this whole thing needs to be overhauled to be composable - private val handWrittenDeserializers: Map Unit> = mapOf( - "stdoptionoptioninstant_epoch_seconds_deser" to { writer -> - writer.write("use #T;", RuntimeType.Deserialize) - writer.rust( - """ - Ok(Option::<#T::InstantEpoch>::deserialize(_deser)?.map(|i|i.0)) - """, - RuntimeType.InstantEpoch - ) - }, - "blob_deser" to { writer -> - writer.rustTemplate( - """ - use #{deserialize}; - use #{de}::Error; - let data = <&str>::deserialize(_deser)?; - #{base64_decode}(data) - .map(Blob::new) - .map_err(|_|D::Error::invalid_value(#{de}::Unexpected::Str(data), &"valid base64")) - - """, - "deserialize" to RuntimeType.Deserialize, - "de" to RuntimeType.Serde("de"), - "base64_decode" to RuntimeType.Base64Decode(runtimeConfig) - ) - }, - "stdoptionoptionblob_deser" to { writer -> - writer.rustTemplate( - """ - use #{deserialize}; - use #{de}::Error; - Option::<&str>::deserialize(_deser)?.map(|data| { - #{base64_decode}(data) - .map(Blob::new) - .map_err(|_|D::Error::invalid_value(#{de}::Unexpected::Str(data), &"valid base64")) - }).transpose() + /** + * Generate a deserializer for the given type dynamically, eg: + * ```rust + * use ::serde::Deserialize; + * Ok( + * Option::::deserialize(_deser)? + * .map(|el| el.0) + * ) + * ``` + * + * It utilizes a newtype that defines the given serialization to access the serde serializer + * then performs any necessary mapping / unmapping. This has a slight disadvantage in that + * that wrapping structures like `Vec` may be allocated twice—I think we should be able to avoid + * this eventually however. + */ + private fun RustWriter.deserializer(t: RustType, memberShape: MemberShape) { + write("use #T;", RuntimeType.Deserialize) + withBlock("Ok(", ")") { + writeSerdeType(t, memberShape) + write("::deserialize(_deser)?") + unrollDeser(t) + } + } - """, - "deserialize" to RuntimeType.Deserialize, - "de" to RuntimeType.Serde("de"), - "base64_decode" to RuntimeType.Base64Decode(runtimeConfig) - ) - }, - "instant_epoch_seconds_deser" to { writer -> - writer.write("use #T;", RuntimeType.Deserialize) - writer.rust( - """ - let ts = f64::deserialize(_deser)?; - Ok(Instant::from_fractional_seconds(ts.floor() as i64, ts - ts.floor())) - """ - ) - }, - "stdoptionoptioninstant_http_date_deser" to { writer -> - writer.write("use #T;", RuntimeType.Deserialize) - writer.rust( - """ - Ok(Option::<#T::InstantHttpDate>::deserialize(_deser)?.map(|i|i.0)) - """, - RuntimeType.InstantHttpDate - ) - }, - "stdoptionoptioninstant_date_time_deser" to { writer -> - writer.write("use #T;", RuntimeType.Deserialize) - writer.rust( - """ - Ok(Option::<#T::InstantIso8601>::deserialize(_deser)?.map(|i|i.0)) - """, - RuntimeType.Instant8601 - ) + private fun RustWriter.unrollDeser(realType: RustType) { + when (realType) { + is RustType.Vec -> withBlock(".into_iter().map(|el|el", ").collect()") { + unrollDeser(realType.member) + } + is RustType.Option -> withBlock(".map(|el|el", ")") { + unrollDeser(realType.value) + } + else -> write(".0") } + } - ) + private fun RustWriter.writeSerdeType(realType: RustType, memberShape: MemberShape) { + when (realType) { + is RustType.Option -> { + withBlock("Option::<", ">") { + writeSerdeType(realType.value, memberShape) + } + } + is RustType.Vec -> { + withBlock("Vec::<", ">") { + writeSerdeType(realType.member, memberShape) + } + } + instant -> { + val format = tsFormat(memberShape) + when (format) { + TimestampFormatTrait.Format.DATE_TIME -> write("#T::InstantIso8601", RuntimeType.Instant8601) + TimestampFormatTrait.Format.EPOCH_SECONDS -> write("#T::InstantEpoch", RuntimeType.InstantEpoch) + TimestampFormatTrait.Format.HTTP_DATE -> write("#T::InstantHttpDate", RuntimeType.InstantHttpDate) + else -> write("todo!() /* unknown timestamp format */") + } + } + blob -> write("#T::BlobDeser", RuntimeType.BlobSerde(runtimeConfig)) + else -> writeWithNoFormatting("todo!() /* not sure what type to use for $realType */") + } + } /** correct argument type for the serde custom serializer */ private fun serializerType(symbol: Symbol): Symbol { @@ -337,7 +326,11 @@ class SerializerBuilder( val fnName = serializerName(rustType, memberShape, "deser") return RuntimeType.forInlineFun(fnName, "serde_util") { writer -> deserializeFn(writer, fnName, symbol) { - handWrittenDeserializers[fnName]?.also { it(this) } ?: write("todo!()") + if (rustType.contains(document)) { + write("todo!()") + } else { + deserializer(rustType, memberShape) + } } } } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/SerializerBuilderTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/SerializerBuilderTest.kt index b0d693c2fa4d3baa1ac6f9c081c4305d9197071a..c640f5ffb448b1b22d79b6c4d763078cbfbc9140 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/SerializerBuilderTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/SerializerBuilderTest.kt @@ -8,8 +8,11 @@ package software.amazon.smithy.rust.codegen.smithy.protocols import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.lang.rust import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.testutil.TestWorkspace import software.amazon.smithy.rust.testutil.asSmithyModel +import software.amazon.smithy.rust.testutil.compileAndTest import software.amazon.smithy.rust.testutil.testSymbolProvider internal class SerializerBuilderTest { @@ -31,4 +34,27 @@ internal class SerializerBuilderTest { serializerBuilder.deserializerFor(model.lookup("test#S\$b"))!!.name shouldBe "stdoptionoptionblob_deser" serializerBuilder.deserializerFor(model.lookup("test#S\$s")) shouldBe null } + + @Test + fun `generate deserializers that compile`() { + val serializerBuilder = SerializerBuilder(provider, model, TimestampFormatTrait.Format.EPOCH_SECONDS) + val timestamp = serializerBuilder.deserializerFor(model.lookup("test#S\$ts"))!! + val blob = serializerBuilder.deserializerFor(model.lookup("test#S\$b"))!! + val writer = TestWorkspace.testProject(provider) + writer.useFileWriter("src/lib.rs", "crate::lib") { + it.rust( + """ + fn foo() { + // commented out so that we generate the import & inject the serializer + // but I don't want to deal with getting the argument to compile + // let _ = #T(); + // let _ = #T(); + } + """, + timestamp, blob + ) + } + println("file:///${writer.baseDir}/src/serde_util.rs") + writer.compileAndTest() + } } diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index 6fb2fe1d33a4128d6cbf785dfdef574e801e96e3..c832d2f23c2835e994fb54be44e196b687873c75 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -13,6 +13,7 @@ are to allow this crate to be compilable and testable in isolation, no client co "serde_json" = "1" "http" = "0.2.1" "smithy-types" = { version = "0.0.1", path = "../smithy-types" } +"smithy-http" = { version = "0.0.1", path = "../smithy-http"} "rand" = "0.7" [dev-dependencies] diff --git a/rust-runtime/inlineable/src/blob_serde.rs b/rust-runtime/inlineable/src/blob_serde.rs new file mode 100644 index 0000000000000000000000000000000000000000..f66e2e5cd13892706c9995619150cda3c7419d2c --- /dev/null +++ b/rust-runtime/inlineable/src/blob_serde.rs @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use serde::de::{Error, Unexpected}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use smithy_http::base64; +use smithy_types::Blob; + +struct BlobSer<'a>(pub &'a Blob); + +impl Serialize for BlobSer<'_> { + fn serialize(&self, serializer: S) -> Result<::Ok, ::Error> + where + S: Serializer, + { + serializer.serialize_str(base64::encode(self.0.as_ref()).as_str()) + } +} + +pub struct BlobDeser(pub Blob); + +impl<'de> Deserialize<'de> for BlobDeser { + fn deserialize(deserializer: D) -> Result>::Error> + where + D: Deserializer<'de>, + { + let data = <&str>::deserialize(deserializer)?; + let bytes = base64::decode(data) + .map_err(|_| D::Error::invalid_value(Unexpected::Str(data), &"valid base64"))?; + Ok(BlobDeser(Blob::new(bytes))) + } +} diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 8715bec075d42d83730a2822abb2f034feb01c1b..0d47ecc9d35c469a784e4ef2a1306e138d462d30 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0. */ +mod blob_serde; #[allow(dead_code)] mod config; mod doc_json;