Unverified Commit 02d0db5d authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generate deserializers dynamically for nested types (#95)

This code replaces the hand-written deserializers that were used to serialize types like Instant & Blob with autogenerated code. A follow-on diff will do the same thing for serializers.

Once sufficiently developed (and if desired) this would allow us to completely remove our use of Serde-derived serializers with serializers & deserializers generated at codegen time.
parent 5c0bfa4f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -90,6 +90,7 @@ class InlineDependency(

        // Stub config implementation as a placeholder before one can be generated dynamically
        fun config() = forRustFile("config", "config", "config.rs", CargoDependency.Rand)
        fun blobSerde(runtimeConfig: RuntimeConfig) = forRustFile("blob_serde", "blob_serde", "blob_serde.rs", CargoDependency.Serde, CargoDependency.SmithyHttp(runtimeConfig))
    }
}

+2 −0
Original line number Diff line number Diff line
@@ -137,6 +137,8 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
        val InstantEpoch = RuntimeType("instant_epoch", InlineDependency.instantEpoch(), "crate")
        val InstantHttpDate = RuntimeType("instant_httpdate", InlineDependency.instantHttpDate(), "crate")
        val Instant8601 = RuntimeType("instant_8601", InlineDependency.instant8601(), "crate")

        fun BlobSerde(runtimeConfig: RuntimeConfig) = RuntimeType("blob_serde", InlineDependency.blobSerde(runtimeConfig), "crate")
        val Config = RuntimeType("config", InlineDependency.config(), "crate")

        fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType(
+65 −72
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.rustTemplate
import software.amazon.smithy.rust.codegen.lang.stripOuter
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.SymbolMetadataProvider
@@ -129,6 +130,7 @@ class SerializerBuilder(

    private val handWrittenSerializers: Map<String, (RustWriter) -> Unit> = mapOf(
        "stdoptionoptionblob_ser" to { writer ->

            writer.rustBlock("match $inp") {
                write(
                    "Some(blob) => $ser.serialize_str(&#T(blob.as_ref())),",
@@ -143,7 +145,6 @@ class SerializerBuilder(
                RuntimeType.Base64Encode(runtimeConfig)
            )
        },

        "stdoptionoptioninstant_http_date_ser" to { writer ->
            val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, TimestampFormatTrait.Format.HTTP_DATE)
            writer.rustBlock("match $inp") {
@@ -181,79 +182,67 @@ class SerializerBuilder(
        }
    )

    // TODO: this whole thing needs to be overhauled to be composable
    private val handWrittenDeserializers: Map<String, (RustWriter) -> Unit> = mapOf(
        "stdoptionoptioninstant_epoch_seconds_deser" to { writer ->
            writer.write("use #T;", RuntimeType.Deserialize)
            writer.rust(
                """
                Ok(Option::<#T::InstantEpoch>::deserialize(_deser)?.map(|i|i.0))
            """,
                RuntimeType.InstantEpoch
            )
        },
        "blob_deser" to { writer ->
            writer.rustTemplate(
                """
                use #{deserialize};
                use #{de}::Error;
                let data = <&str>::deserialize(_deser)?;
                #{base64_decode}(data)
                    .map(Blob::new)
                    .map_err(|_|D::Error::invalid_value(#{de}::Unexpected::Str(data), &"valid base64"))

            """,
                "deserialize" to RuntimeType.Deserialize,
                "de" to RuntimeType.Serde("de"),
                "base64_decode" to RuntimeType.Base64Decode(runtimeConfig)
            )
        },
        "stdoptionoptionblob_deser" to { writer ->
            writer.rustTemplate(
                """
                use #{deserialize};
                use #{de}::Error;
                Option::<&str>::deserialize(_deser)?.map(|data| {
                    #{base64_decode}(data)
                        .map(Blob::new)
                        .map_err(|_|D::Error::invalid_value(#{de}::Unexpected::Str(data), &"valid base64"))
                }).transpose()
    /**
     * Generate a deserializer for the given type dynamically, eg:
     * ```rust
     *  use ::serde::Deserialize;
     *  Ok(
     *      Option::<crate::instant_epoch::InstantEpoch>::deserialize(_deser)?
     *          .map(|el| el.0)
     *  )
     * ```
     *
     * It utilizes a newtype that defines the given serialization to access the serde serializer
     * then performs any necessary mapping / unmapping. This has a slight disadvantage in that
     * that wrapping structures like `Vec` may be allocated twice—I think we should be able to avoid
     * this eventually however.
     */
    private fun RustWriter.deserializer(t: RustType, memberShape: MemberShape) {
        write("use #T;", RuntimeType.Deserialize)
        withBlock("Ok(", ")") {
            writeSerdeType(t, memberShape)
            write("::deserialize(_deser)?")
            unrollDeser(t)
        }
    }

            """,
                "deserialize" to RuntimeType.Deserialize,
                "de" to RuntimeType.Serde("de"),
                "base64_decode" to RuntimeType.Base64Decode(runtimeConfig)
            )
        },
        "instant_epoch_seconds_deser" to { writer ->
            writer.write("use #T;", RuntimeType.Deserialize)
            writer.rust(
                """
                let ts = f64::deserialize(_deser)?;
                Ok(Instant::from_fractional_seconds(ts.floor() as i64, ts - ts.floor()))
            """
            )
        },
        "stdoptionoptioninstant_http_date_deser" to { writer ->
            writer.write("use #T;", RuntimeType.Deserialize)
            writer.rust(
                """
                Ok(Option::<#T::InstantHttpDate>::deserialize(_deser)?.map(|i|i.0))
            """,
                RuntimeType.InstantHttpDate
            )
        },
        "stdoptionoptioninstant_date_time_deser" to { writer ->
            writer.write("use #T;", RuntimeType.Deserialize)
            writer.rust(
                """
                Ok(Option::<#T::InstantIso8601>::deserialize(_deser)?.map(|i|i.0))
            """,
                RuntimeType.Instant8601
            )
    private fun RustWriter.unrollDeser(realType: RustType) {
        when (realType) {
            is RustType.Vec -> withBlock(".into_iter().map(|el|el", ").collect()") {
                unrollDeser(realType.member)
            }
            is RustType.Option -> withBlock(".map(|el|el", ")") {
                unrollDeser(realType.value)
            }
            else -> write(".0")
        }
    }

    )
    private fun RustWriter.writeSerdeType(realType: RustType, memberShape: MemberShape) {
        when (realType) {
            is RustType.Option -> {
                withBlock("Option::<", ">") {
                    writeSerdeType(realType.value, memberShape)
                }
            }
            is RustType.Vec -> {
                withBlock("Vec::<", ">") {
                    writeSerdeType(realType.member, memberShape)
                }
            }
            instant -> {
                val format = tsFormat(memberShape)
                when (format) {
                    TimestampFormatTrait.Format.DATE_TIME -> write("#T::InstantIso8601", RuntimeType.Instant8601)
                    TimestampFormatTrait.Format.EPOCH_SECONDS -> write("#T::InstantEpoch", RuntimeType.InstantEpoch)
                    TimestampFormatTrait.Format.HTTP_DATE -> write("#T::InstantHttpDate", RuntimeType.InstantHttpDate)
                    else -> write("todo!() /* unknown timestamp format */")
                }
            }
            blob -> write("#T::BlobDeser", RuntimeType.BlobSerde(runtimeConfig))
            else -> writeWithNoFormatting("todo!() /* not sure what type to use for $realType */")
        }
    }

    /** correct argument type for the serde custom serializer */
    private fun serializerType(symbol: Symbol): Symbol {
@@ -337,7 +326,11 @@ class SerializerBuilder(
        val fnName = serializerName(rustType, memberShape, "deser")
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            deserializeFn(writer, fnName, symbol) {
                handWrittenDeserializers[fnName]?.also { it(this) } ?: write("todo!()")
                if (rustType.contains(document)) {
                    write("todo!()")
                } else {
                    deserializer(rustType, memberShape)
                }
            }
        }
    }
+26 −0
Original line number Diff line number Diff line
@@ -8,8 +8,11 @@ package software.amazon.smithy.rust.codegen.smithy.protocols
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.testutil.TestWorkspace
import software.amazon.smithy.rust.testutil.asSmithyModel
import software.amazon.smithy.rust.testutil.compileAndTest
import software.amazon.smithy.rust.testutil.testSymbolProvider

internal class SerializerBuilderTest {
@@ -31,4 +34,27 @@ internal class SerializerBuilderTest {
        serializerBuilder.deserializerFor(model.lookup("test#S\$b"))!!.name shouldBe "stdoptionoptionblob_deser"
        serializerBuilder.deserializerFor(model.lookup("test#S\$s")) shouldBe null
    }

    @Test
    fun `generate deserializers that compile`() {
        val serializerBuilder = SerializerBuilder(provider, model, TimestampFormatTrait.Format.EPOCH_SECONDS)
        val timestamp = serializerBuilder.deserializerFor(model.lookup("test#S\$ts"))!!
        val blob = serializerBuilder.deserializerFor(model.lookup("test#S\$b"))!!
        val writer = TestWorkspace.testProject(provider)
        writer.useFileWriter("src/lib.rs", "crate::lib") {
            it.rust(
                """
                fn foo() {
                    // commented out so that we generate the import & inject the serializer
                    // but I don't want to deal with getting the argument to compile
                    // let _ = #T();
                    // let _ = #T();
                }
            """,
                timestamp, blob
            )
        }
        println("file:///${writer.baseDir}/src/serde_util.rs")
        writer.compileAndTest()
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ are to allow this crate to be compilable and testable in isolation, no client co
"serde_json" = "1"
"http" = "0.2.1"
"smithy-types" = { version = "0.0.1", path = "../smithy-types" }
"smithy-http" = { version = "0.0.1", path = "../smithy-http"}
"rand" = "0.7"

[dev-dependencies]
Loading