Unverified Commit d89d58ca authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for deserializing bodies (except special shapes) (#57)

* Add support for deserializing bodies

* Don't mark top level shapes as serializable

* Use positional formatters to shorten code in a few spots
parent efdca448
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -162,7 +162,7 @@ class RustWriter private constructor(
     *
     * Callers must take care to use [this] when writing to ensure code is written to the right place:
     * ```kotlin
     * val writer = RustWriter.forModule("models")
     * val writer = RustWriter.forModule("model")
     * writer.withModule("nested") {
     *   Generator(...).render(this) // GOOD
     *   Generator(...).render(writer) // WRONG!
+16 −4
Original line number Diff line number Diff line
@@ -78,10 +78,18 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
            RuntimeType(func, CargoDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::query")

        fun Base64Encode(runtimeConfig: RuntimeConfig): RuntimeType =
            RuntimeType("encode", CargoDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64")
            RuntimeType(
                "encode",
                CargoDependency.SmithyHttp(runtimeConfig),
                "${runtimeConfig.cratePrefix}_http::base64"
            )

        fun Base64Decode(runtimeConfig: RuntimeConfig): RuntimeType =
            RuntimeType("decode", CargoDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64")
            RuntimeType(
                "decode",
                CargoDependency.SmithyHttp(runtimeConfig),
                "${runtimeConfig.cratePrefix}_http::base64"
            )

        fun TimestampFormat(runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format): RuntimeType {
            val timestampFormat = when (format) {
@@ -102,13 +110,17 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
                func, CargoDependency.ProtocolTestHelpers(runtimeConfig), "protocol_test_helpers"
            )

        fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http")
        fun Http(path: String): RuntimeType =
            RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http")

        val HttpRequestBuilder = Http("request::Builder")

        val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde")
        val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde")
        val Serializer = RuntimeType("Serializer", CargoDependency.Serde, namespace = "serde")
        fun SerdeJson(path: String) = RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json")
        val Deserializer = RuntimeType("Deserializer", CargoDependency.Serde, namespace = "serde")
        fun SerdeJson(path: String) =
            RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

        fun forInlineFun(name: String, module: String, func: (RustWriter) -> Unit) = RuntimeType(
            name = name,
+0 −1
Original line number Diff line number Diff line
@@ -51,7 +51,6 @@ class HttpProtocolTestGenerator(
        // Misc:
        "RestJsonQueryIdempotencyTokenAutoFill", // https://github.com/awslabs/smithy-rs/issues/34
        "RestJsonHttpPrefixHeadersArePresent" // https://github.com/awslabs/smithy-rs/issues/35

    )
    private val inputShape = operationShape.inputShape(protocolConfig.model)
    fun render() {
+2 −4
Original line number Diff line number Diff line
@@ -80,8 +80,7 @@ class HttpTraitBindingGenerator(
        val hasHeaders = addHeaders(implBlockWriter)
        val hasQuery = uriQuery(implBlockWriter)
        implBlockWriter.rustBlock(
            "fn update_http_builder(&self, builder: \$T) -> \$T",
            RuntimeType.HttpRequestBuilder,
            "fn update_http_builder(&self, builder: \$1T) -> \$1T",
            RuntimeType.HttpRequestBuilder
        ) {
            write("let mut uri = String::new();")
@@ -108,8 +107,7 @@ class HttpTraitBindingGenerator(
            return false
        }
        writer.rustBlock(
            "fn add_headers(&self, mut builder: \$T) -> \$T",
            RuntimeType.HttpRequestBuilder,
            "fn add_headers(&self, mut builder: \$1T) -> \$1T",
            RuntimeType.HttpRequestBuilder
        ) {
            headers.forEach { httpBinding ->
+139 −130
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.protocols

import org.intellij.lang.annotations.Language
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
@@ -15,6 +14,7 @@ import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.lang.Attribute
import software.amazon.smithy.rust.codegen.lang.Custom
import software.amazon.smithy.rust.codegen.lang.RustMetadata
import software.amazon.smithy.rust.codegen.lang.RustType
@@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.lang.contains
import software.amazon.smithy.rust.codegen.lang.render
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.stripOuter
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.SymbolMetadataProvider
@@ -32,6 +31,8 @@ import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.dq

/**
@@ -40,8 +41,7 @@ import software.amazon.smithy.rust.codegen.util.dq
class JsonSerializerSymbolProvider(
    private val model: Model,
    private val base: RustSymbolProvider,
    private val defaultTimestampFormat: TimestampFormatTrait.Format

    defaultTimestampFormat: TimestampFormatTrait.Format
) :
    SymbolMetadataProvider(base) {

@@ -50,143 +50,73 @@ class JsonSerializerSymbolProvider(
    private fun MemberShape.serializedName() =
        this.getTrait(JsonNameTrait::class.java).map { it.value }.orElse(this.memberName)

    private val httpIndex = HttpBindingIndex.of(model)
    private val serializerBuilder = SerializerBuilder(base.config().runtimeConfig)
    private val serializerBuilder = SerializerBuilder(base, model, defaultTimestampFormat)
    override fun memberMeta(memberShape: MemberShape): RustMetadata {
        val currentMeta = base.toSymbol(memberShape).expectRustMetadata()
        val serdeConfig = serdeRequired(model.expectShape(memberShape.container))
        val attribs = mutableListOf<Attribute>()
        if (serdeConfig.serialize || serdeConfig.deserialize) {
            attribs.add(Custom("serde(rename = ${memberShape.serializedName().dq()})"))
        }
        if (serdeConfig.serialize) {
            val skipIfNone =
            if (base.toSymbol(memberShape).rustType().stripOuter<RustType.Reference>() is RustType.Option) {
                    listOf(Custom("serde(skip_serializing_if = \"Option::is_none\")"))
                } else {
                    listOf()
                }
            val renameAttribute = Custom("serde(rename = ${memberShape.serializedName().dq()})")
            val serializer = serializerFor(memberShape)
            val serdeAttribute = serializer?.let {
                listOf(Custom("serde(serialize_with = ${serializer.fullyQualifiedName().dq()})", listOf(it)))
            } ?: listOf()
            return currentMeta.copy(additionalAttributes = currentMeta.additionalAttributes + renameAttribute + serdeAttribute + skipIfNone)
        } else {
            return currentMeta
                attribs.add(Custom("serde(skip_serializing_if = \"Option::is_none\")"))
            }
            serializerBuilder.serializerFor(memberShape)?.also {
                attribs.add(Custom("serde(serialize_with = ${it.fullyQualifiedName().dq()})", listOf(it)))
            }
        }
        if (serdeConfig.deserialize) {
            serializerBuilder.deserializerFor(memberShape)?.also {
                attribs.add(Custom("serde(deserialize_with = ${it.fullyQualifiedName().dq()})", listOf(it)))
            }
        }
        return currentMeta.copy(additionalAttributes = currentMeta.additionalAttributes + attribs)
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata = containerMeta(structureShape)
    override fun unionMeta(unionShape: UnionShape): RustMetadata = containerMeta(unionShape)
    override fun enumMeta(stringShape: StringShape): RustMetadata = containerMeta(stringShape)

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        val currentMeta = base.toSymbol(structureShape).expectRustMetadata()
        val requiredSerde = serdeRequired(structureShape)
    private fun containerMeta(container: Shape): RustMetadata {
        val currentMeta = base.toSymbol(container).expectRustMetadata()
        val requiredSerde = serdeRequired(container)
        return currentMeta
            .letIf(requiredSerde.serialize) { it.withDerive(RuntimeType.Serialize) }
        // TODO: generate deserializers
        // .letIf(requiredSerde.deserialize) { it.withDerive(RuntimeType.Deserialize) }
            .letIf(requiredSerde.deserialize) { it.withDerive(RuntimeType.Deserialize) }
    }

    private fun serdeRequired(shape: Shape): SerdeConfig {
        return when {
            shape.hasTrait(InputBodyTrait::class.java) -> SerdeConfig(serialize = true, deserialize = false)
            shape.hasTrait(OutputBodyTrait::class.java) -> SerdeConfig(serialize = false, deserialize = true)
            else -> SerdeConfig(serialize = true, deserialize = true)
        }
    }

    override fun unionMeta(unionShape: UnionShape): RustMetadata {
        val currentMeta = base.toSymbol(unionShape).expectRustMetadata()
        return currentMeta.withDerive(RuntimeType.Serialize)
    }

    override fun enumMeta(stringShape: StringShape): RustMetadata {
        val currentMeta = base.toSymbol(stringShape).expectRustMetadata()
        return currentMeta.withDerive(RuntimeType.Serialize)
    }

    private fun serializerFor(memberShape: MemberShape): RuntimeType? {
        val rustType = base.toSymbol(memberShape).rustType()
        val instant = base.toSymbol(TimestampShape.builder().id("dummy#ts").build()).rustType()
        val blob = base.toSymbol(BlobShape.builder().id("dummy#ts").build()).rustType()
        val document = base.toSymbol(DocumentShape.builder().id("dummy#ts").build()).rustType()
        if (!(rustType.contains(blob) || rustType.contains(instant) || rustType.contains(document))) {
            return null
        }
        val targetType = rustType.stripOuter<RustType.Reference>()
        val typeAsFunctionName = targetType.render().filter { it.isLetterOrDigit() }.toLowerCase()
        return when {
            rustType.contains(instant) -> instantSerializer(memberShape, typeAsFunctionName, targetType)
            rustType.contains(blob) -> blobSerializer(memberShape, typeAsFunctionName, targetType)
            rustType.contains(document) -> documentSerializer(memberShape, typeAsFunctionName, targetType)
            else -> null
        }
    }

    private fun documentSerializer(
        memberShape: MemberShape,
        typeAsFunctionName: String,
        argType: RustType
    ): RuntimeType {
        val symbol = base.toSymbol(memberShape)
        val fnName = typeAsFunctionName
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            serializeFn(writer, fnName, symbol, argType) {
                write("todo!()")
            // The bodies must be serializable. The top level inputs are _not_
            shape.hasTrait(SyntheticInputTrait::class.java) -> SerdeConfig(serialize = false, deserialize = false)
            shape.hasTrait(SyntheticOutputTrait::class.java) -> SerdeConfig(serialize = false, deserialize = false)
            else -> SerdeConfig(serialize = true, deserialize = true)
        }
    }
}

    private fun serializeFn(
        rustWriter: RustWriter,
        functionName: String,
        symbol: Symbol,
        targetType: RustType,
        body: RustWriter.() -> Unit
    ) {
        // Convert Vec<T> to `[T]` when present. This is needed to avoid
        // Clippy complaining (and is also better in general).
        val sliceToVec = when (targetType) {
            is RustType.Vec -> RustType.Slice(targetType.member)
            else -> targetType
        }
        val ref = RustType.Reference(lifetime = null, value = sliceToVec)
        val newSymbol = symbol.toBuilder().rustType(ref).build()
        rustWriter.rustBlock(
            "pub fn $functionName<S>(_inp: \$T, _serializer: S) -> " +
                "Result<<S as \$T>::Ok, <S as \$T>::Error> where S: \$T",
            newSymbol,
            RuntimeType.Serializer,
            RuntimeType.Serializer,
            RuntimeType.Serializer
class SerializerBuilder(
    private val symbolProvider: RustSymbolProvider,
    model: Model,
    private val defaultTimestampFormat: TimestampFormatTrait.Format
) {
            body(this)
        }
    }

    private fun blobSerializer(memberShape: MemberShape, baseTypeName: String, argType: RustType): RuntimeType {
        val symbol = base.toSymbol(memberShape)
        val fnName = baseTypeName
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            serializeFn(writer, fnName, symbol, argType) {
                serializerBuilder.render(this, baseTypeName)
            }
        }
    }

    private fun instantSerializer(memberShape: MemberShape, baseTypeName: String, argType: RustType): RuntimeType {
        val instantFormat =
            httpIndex.determineTimestampFormat(memberShape, HttpBinding.Location.PAYLOAD, defaultTimestampFormat)
        val symbol = base.toSymbol(memberShape)
        val fnName = "${baseTypeName}_${instantFormat.name.replace('-', '_').toLowerCase()}"
        return RuntimeType.forInlineFun(fnName, "serde_util") { rustWriter: RustWriter ->
            serializeFn(rustWriter, fnName, symbol, argType) {
                serializerBuilder.render(this, fnName)
            }
        }
    }
}

class SerializerBuilder(runtimeConfig: RuntimeConfig) {
    private val inp = "_inp"
    private val ser = "_serializer"
    private val HandWrittenSerializers: Map<String, (RustWriter) -> Unit> = mapOf(
        "optionblob" to { writer ->
    private val httpBindingIndex = HttpBindingIndex.of(model)
    private val runtimeConfig = symbolProvider.config().runtimeConfig

    // Small hack to get the Rust type for these problematic shapes
    private val instant = symbolProvider.toSymbol(TimestampShape.builder().id("dummy#ts").build()).rustType()
    private val blob = symbolProvider.toSymbol(BlobShape.builder().id("dummy#blob").build()).rustType()
    private val document = symbolProvider.toSymbol(DocumentShape.builder().id("dummy#doc").build()).rustType()
    private val customShapes = setOf(instant, blob, document)

    private val handWrittenSerializers: Map<String, (RustWriter) -> Unit> = mapOf(
        "optionblob_ser" to { writer ->
            writer.rustBlock("match $inp") {
                write(
                    "Some(blob) => $ser.serialize_str(&\$T(blob.as_ref())),",
@@ -195,14 +125,14 @@ class SerializerBuilder(runtimeConfig: RuntimeConfig) {
                write("None => $ser.serialize_none()")
            }
        },
        "blob" to { writer ->
        "blob_ser" to { writer ->
            writer.write(
                "$ser.serialize_str(&\$T($inp.as_ref()))",
                RuntimeType.Base64Encode(runtimeConfig)
            )
        },

        "optioninstant_http_date" to { writer ->
        "optioninstant_http_date_ser" to { writer ->
            val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, TimestampFormatTrait.Format.HTTP_DATE)
            writer.rustBlock("match $inp") {
                write(
@@ -211,7 +141,7 @@ class SerializerBuilder(runtimeConfig: RuntimeConfig) {
                write("None => _serializer.serialize_none()")
            }
        },
        "optioninstant_date_time" to { writer ->
        "optioninstant_date_time_ser" to { writer ->
            val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, TimestampFormatTrait.Format.DATE_TIME)
            writer.rustBlock("match $inp") {
                write(
@@ -220,22 +150,101 @@ class SerializerBuilder(runtimeConfig: RuntimeConfig) {
                write("None => _serializer.serialize_none()")
            }
        },
        "optioninstant_epoch_seconds" to { writer ->
        "optioninstant_epoch_seconds_ser" to { writer ->
            writer.rustBlock("match $inp") {
                write("Some(ts) => $ser.serialize_some(&ts.epoch_seconds()),")
                write("None => _serializer.serialize_none()")
            }
        },
        "instant_epoch_seconds" to { writer ->
        "instant_epoch_seconds_ser" to { writer ->
            writer.write("$ser.serialize_i64($inp.epoch_seconds())")
        }
    )

    fun render(writer: RustWriter, name: String) =
        HandWrittenSerializers[name]?.also { it(writer) } ?: writer.write("todo!()")
    /** correct argument type for the serde custom serializer */
    private fun serializerType(symbol: Symbol): Symbol {
        val unref = symbol.rustType().stripOuter<RustType.Reference>()

        // Convert `Vec<T>` to `[T]` when present. This is needed to avoid
        // Clippy complaining (and is also better in general).
        val outType = when (unref) {
            is RustType.Vec -> RustType.Slice(unref.member)
            else -> unref
        }
        val referenced = RustType.Reference(value = outType, lifetime = null)
        return symbol.toBuilder().rustType(referenced).build()
    }

@Language("Rust")
private fun String.rust(): String {
    return this
    private fun tsFormat(memberShape: MemberShape) =
        httpBindingIndex.determineTimestampFormat(memberShape, HttpBinding.Location.PAYLOAD, defaultTimestampFormat)

    private fun serializerName(rustType: RustType, memberShape: MemberShape, suffix: String): String {
        val context = when {
            rustType.contains(instant) -> tsFormat(memberShape).name.replace('-', '_').toLowerCase()
            else -> null
        }
        val typeToFnName =
            rustType.stripOuter<RustType.Reference>().render().filter { it.isLetterOrDigit() }.toLowerCase()
        return listOfNotNull(typeToFnName, context, suffix).joinToString("_")
    }

    private fun serializeFn(
        rustWriter: RustWriter,
        functionName: String,
        symbol: Symbol,
        body: RustWriter.() -> Unit
    ) {
        rustWriter.rustBlock(
            "pub fn $functionName<S>(_inp: \$1T, _serializer: S) -> " +
                "Result<<S as \$2T>::Ok, <S as \$2T>::Error> where S: \$2T",
            serializerType(symbol),
            RuntimeType.Serializer
        ) {
            body(this)
        }
    }

    private fun deserializeFn(
        rustWriter: RustWriter,
        functionName: String,
        symbol: Symbol,
        body: RustWriter.() -> Unit
    ) {
        rustWriter.rustBlock(
            "pub fn $functionName<'de, D>(_deser: D) -> Result<\$T, D::Error> where D: \$T<'de>",
            symbol,
            RuntimeType.Deserializer
        ) {
            body(this)
        }
    }

    fun serializerFor(memberShape: MemberShape): RuntimeType? {
        val symbol = symbolProvider.toSymbol(memberShape)
        val rustType = symbol.rustType()
        if (customShapes.none { rustType.contains(it) }) {
            return null
        }
        val fnName = serializerName(rustType, memberShape, "ser")
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            serializeFn(writer, fnName, symbol) {
                handWrittenSerializers[fnName]?.also { it(this) } ?: write("todo!()")
            }
        }
    }

    fun deserializerFor(memberShape: MemberShape): RuntimeType? {
        val symbol = symbolProvider.toSymbol(memberShape)
        val rustType = symbol.rustType()
        if (customShapes.none { rustType.contains(it) }) {
            return null
        }
        val fnName = serializerName(rustType, memberShape, "deser")
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            deserializeFn(writer, fnName, symbol) {
                // TODO: implement deserializers
                write("todo!()")
            }
        }
    }
}
Loading