Unverified Commit 32b55125 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

wip (#51)

parent be16c2f2
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -106,6 +106,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        val HttpRequestBuilder = Http("request::Builder")

        val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde")
        val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde")
        val Serializer = RuntimeType("Serializer", CargoDependency.Serde, namespace = "serde")
        fun SerdeJson(path: String) = RuntimeType(path, dependency = CargoDependency.SerdeJson, namespace = "serde_json")

+10 −2
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait

class ServiceGenerator(
    private val writers: CodegenWriterDelegator<RustWriter>,
@@ -32,14 +33,21 @@ class ServiceGenerator(

    private fun renderBodies() {
        val operations = index.getContainedOperations(config.serviceShape)
        val bodies = operations.map { config.model.expectShape(it.input.get()) }.map {
        val inputBodies = operations.map { config.model.expectShape(it.input.get()) }.map {
            it.expectTrait(SyntheticInputTrait::class.java)
        }.mapNotNull { // mapNotNull is flatMap but for null `map { it }.filter { it != null }`
            it.body
        }.map { // Lookup the Body structure by its id
            config.model.expectShape(it, StructureShape::class.java)
        }
        bodies.map { body ->
        val outputBodies = operations.map { config.model.expectShape(it.output.get()) }.map {
            it.expectTrait(SyntheticOutputTrait::class.java)
        }.mapNotNull { // mapNotNull is flatMap but for null `map { it }.filter { it != null }`
            it.body
        }.map { // Lookup the Body structure by its id
            config.model.expectShape(it, StructureShape::class.java)
        }
        (inputBodies + outputBodies).map { body ->
            // The body symbol controls its location, usually in the serializer module
            writers.useShapeWriter(body) { writer ->
                with(config) {
+29 −7
Original line number Diff line number Diff line
@@ -28,7 +28,9 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.locatedIn
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier

sealed class AwsJsonVersion {
    abstract val value: String
@@ -47,13 +49,20 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat
        protocolConfig: ProtocolConfig
    ): BasicAwsJsonGenerator = BasicAwsJsonGenerator(protocolConfig, version)

    override fun transformModel(model: Model): Model {
        // For AwsJson10, the body matches 1:1 with the input
        return OperationNormalizer().transformModel(model) { inputShape ->
            if (inputShape != null && inputShape.members().isEmpty()) {
    private val shapeIfHasMembers: StructureModifier = { shape: StructureShape? ->
        if (shape?.members().isNullOrEmpty()) {
            null
            } else inputShape
        } else {
            shape
        }
    }

    override fun transformModel(model: Model): Model {
        // For AwsJson10, the body matches 1:1 with the input
        return OperationNormalizer(model).transformModel(
            inputBodyFactory = shapeIfHasMembers,
            outputBodyFactory = shapeIfHasMembers
        )
    }

    override fun symbolProvider(model: Model, base: RustSymbolProvider): SymbolProvider {
@@ -70,14 +79,27 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat
/**
 * SyntheticBodySymbolProvider makes two modifications:
 * 1. Body shapes are moved to `serializer.rs`
 * 2. Body shapes take a reference to all of their members.
 * 2. Body shapes take a reference to all of their members:
 * If the base structure was:
 * ```rust
 * struct {
 *   field: Option<u64>
 * }
 * ```
 * The body will generate:
 * ```rust
 * struct<'a> {
 *   field: &'a Option<u64>
 * }
 *
 * This enables the creation of a body from a reference to an input without cloning.
 */
class SyntheticBodySymbolProvider(private val model: Model, private val base: RustSymbolProvider) :
    WrappingSymbolProvider(base) {
    override fun toSymbol(shape: Shape): Symbol {
        val initialSymbol = base.toSymbol(shape)
        val override = when (shape) {
            is StructureShape -> if (shape.hasTrait(InputBodyTrait::class.java)) {
            is StructureShape -> if (shape.hasTrait(InputBodyTrait::class.java) || shape.hasTrait(OutputBodyTrait::class.java)) {
                initialSymbol.toBuilder().locatedIn(Serializers).build()
            } else null
            is MemberShape -> {
+6 −2
Original line number Diff line number Diff line
@@ -27,7 +27,10 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {

    override fun transformModel(model: Model): Model {
        // TODO: AWSRestJson determines the body from HTTP traits
        return OperationNormalizer().transformModel(model, OperationNormalizer.noBody)
        return OperationNormalizer(model).transformModel(
            inputBodyFactory = OperationNormalizer.NoBody,
            outputBodyFactory = OperationNormalizer.NoBody
        )
    }

    override fun support(): ProtocolSupport {
@@ -69,7 +72,8 @@ class AwsRestJsonGenerator(
            inputShape,
            httpTrait
        )
        val contentType = httpIndex.determineRequestContentType(operationShape, "application/json").orElse("application/json")
        val contentType =
            httpIndex.determineRequestContentType(operationShape, "application/json").orElse("application/json")
        httpBindingGenerator.renderUpdateHttpBuilder(implBlockWriter)
        httpBuilderFun(implBlockWriter) {
            write("let builder = \$T::new();", requestBuilder)
+38 −15
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
@@ -27,7 +28,10 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.SymbolMetadataProvider
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.util.dq

/**
@@ -41,13 +45,17 @@ class JsonSerializerSymbolProvider(
) :
    SymbolMetadataProvider(base) {

    data class SerdeConfig(val serialize: Boolean, val deserialize: Boolean)

    private fun MemberShape.serializedName() =
        this.getTrait(JsonNameTrait::class.java).map { it.value }.orElse(this.memberName)

    val httpIndex = HttpBindingIndex.of(model)
    val serializerBuilder = SerializerBuilder(base.config().runtimeConfig)
    private val httpIndex = HttpBindingIndex.of(model)
    private val serializerBuilder = SerializerBuilder(base.config().runtimeConfig)
    override fun memberMeta(memberShape: MemberShape): RustMetadata {
        val currentMeta = base.toSymbol(memberShape).expectRustMetadata()
        val serdeConfig = serdeRequired(model.expectShape(memberShape.container))
        if (serdeConfig.serialize) {
            val skipIfNone =
                if (base.toSymbol(memberShape).rustType().stripOuter<RustType.Reference>() is RustType.Option) {
                    listOf(Custom("serde(skip_serializing_if = \"Option::is_none\")"))
@@ -60,11 +68,26 @@ class JsonSerializerSymbolProvider(
                listOf(Custom("serde(serialize_with = ${serializer.fullyQualifiedName().dq()})", listOf(it)))
            } ?: listOf()
            return currentMeta.copy(additionalAttributes = currentMeta.additionalAttributes + renameAttribute + serdeAttribute + skipIfNone)
        } else {
            return currentMeta
        }
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        val currentMeta = base.toSymbol(structureShape).expectRustMetadata()
        return currentMeta.withDerive(RuntimeType.Serialize)
        val requiredSerde = serdeRequired(structureShape)
        return currentMeta
            .letIf(requiredSerde.serialize) { it.withDerive(RuntimeType.Serialize) }
        // TODO: generate deserializers
        // .letIf(requiredSerde.deserialize) { it.withDerive(RuntimeType.Deserialize) }
    }

    private fun serdeRequired(shape: Shape): SerdeConfig {
        return when {
            shape.hasTrait(InputBodyTrait::class.java) -> SerdeConfig(serialize = true, deserialize = false)
            shape.hasTrait(OutputBodyTrait::class.java) -> SerdeConfig(serialize = false, deserialize = true)
            else -> SerdeConfig(serialize = true, deserialize = true)
        }
    }

    override fun unionMeta(unionShape: UnionShape): RustMetadata {
Loading