Unverified Commit f528c523 authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

ServerInstantiator without builder (#3094)



Build structures without the builder

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent 22215502
Loading
Loading
Loading
Loading
+48 −9
Original line number Diff line number Diff line
@@ -46,6 +46,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
@@ -89,6 +90,7 @@ open class Instantiator(
    /** Fill out required fields with a default value. **/
    private val defaultsForRequiredFields: Boolean = false,
    private val customizations: List<InstantiatorCustomization> = listOf(),
    private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER,
) {
    data class Ctx(
        // The `http` crate requires that headers be lowercase, but Smithy protocol tests
@@ -96,6 +98,16 @@ open class Instantiator(
        val lowercaseMapKeys: Boolean = false,
    )

    /**
     * A struct can be built by:
     * * direct instantiation: A { field_1: value_1, field_2: value_2 }
     * * its associated builder: A::builder().field_1(value_1).field_2(value_2).build()
     */
    enum class InstantiatorConstructPattern {
        DIRECT,
        BUILDER,
    }

    /**
     * Client and server structures have different builder types. `Instantiator` needs to know how the builder
     * type behaves to generate code for it.
@@ -274,6 +286,12 @@ open class Instantiator(
        headers: Map<String, String>,
        ctx: Ctx,
    ) {
        when (constructPattern) {
            InstantiatorConstructPattern.DIRECT ->
                writer.withBlockTemplate("#{T} {", "}", "T" to symbolProvider.toSymbol(shape)) {
                    renderStructureMembers(writer, shape, data, headers, ctx)
                }
            InstantiatorConstructPattern.BUILDER -> {
                writer.rust("#T::builder()", symbolProvider.toSymbol(shape))

                renderStructureMembers(writer, shape, data, headers, ctx)
@@ -283,6 +301,8 @@ open class Instantiator(
                    writer.rust(".unwrap()")
                }
            }
        }
    }

    protected fun renderStructureMembers(
        writer: RustWriter,
@@ -291,12 +311,24 @@ open class Instantiator(
        headers: Map<String, String>,
        ctx: Ctx,
    ) {
        val renderedMembers = mutableSetOf<MemberShape>()
        fun renderMemberHelper(memberShape: MemberShape, value: Node) {
            renderedMembers.add(memberShape)
            when (constructPattern) {
                InstantiatorConstructPattern.DIRECT -> {
                    val fieldName = symbolProvider.toMemberName(memberShape)
                    writer.withBlock("$fieldName:", ",") {
                        renderMember(this, memberShape, value, ctx)
                    }
                }
                InstantiatorConstructPattern.BUILDER -> {
                    val setterName = builderKindBehavior.setterName(memberShape)
                    writer.withBlock(".$setterName(", ")") {
                        renderMember(this, memberShape, value, ctx)
                    }
                }
            }
        }

        if (defaultsForRequiredFields) {
            shape.allMembers.entries
@@ -313,7 +345,7 @@ open class Instantiator(
                .filter { it.value.hasTrait<HttpHeaderTrait>() }
                .forEach { (_, value) ->
                    val trait = value.expectTrait<HttpHeaderTrait>().value
                    headers.get(trait)?.let { renderMemberHelper(value, Node.from(it)) }
                    headers[trait]?.let { renderMemberHelper(value, Node.from(it)) }
                }
        }

@@ -331,6 +363,13 @@ open class Instantiator(
            ?.let {
                renderMemberHelper(it.value, fillDefaultValue(model.expectShape(it.value.target)))
            }

        if (constructPattern == InstantiatorConstructPattern.DIRECT) {
            val membersToRender = shape.allMembers.values.minus(renderedMembers)
            check(membersToRender.all { it.isOptional })
            membersToRender
                .forEach { renderMemberHelper(it, Node.nullNode()) }
        }
    }

    /**
+9 −9
Original line number Diff line number Diff line
@@ -68,14 +68,14 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat
        codegenContext.symbolProvider.toSymbol(memberShape).isOptional()
}

fun serverInstantiator(codegenContext: CodegenContext) =
    Instantiator(
class ServerInstantiator(codegenContext: CodegenContext) : Instantiator(
    codegenContext.symbolProvider,
    codegenContext.model,
    codegenContext.runtimeConfig,
    ServerBuilderKindBehavior(codegenContext),
    defaultsForRequiredFields = true,
    customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)),
    constructPattern = InstantiatorConstructPattern.DIRECT,
)

class ServerBuilderInstantiator(
+2 −2
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.serverInstantiator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerInstantiator
import java.util.logging.Logger
import kotlin.reflect.KFunction1

@@ -94,7 +94,7 @@ class ServerProtocolTestGenerator(
        inputT to outputT
    }

    private val instantiator = serverInstantiator(codegenContext)
    private val instantiator = ServerInstantiator(codegenContext)

    private val codegenScope = arrayOf(
        "Bytes" to RuntimeType.Bytes,
+1 −1
Original line number Diff line number Diff line
@@ -179,7 +179,7 @@ class ConstrainedCollectionGeneratorTest {
            project.withModule(ServerRustModule.Model) {
                render(codegenContext, this, shape)

                val instantiator = serverInstantiator(codegenContext)
                val instantiator = ServerInstantiator(codegenContext)
                for ((idx, validList) in testCase.validLists.withIndex()) {
                    val shapeNameIdx = "${shapeName}_$idx"
                    val buildValidFnName = "build_valid_$shapeNameIdx"
+1 −1
Original line number Diff line number Diff line
@@ -81,7 +81,7 @@ class ConstrainedMapGeneratorTest {
        project.withModule(ServerRustModule.Model) {
            render(codegenContext, this, constrainedMapShape)

            val instantiator = serverInstantiator(codegenContext)
            val instantiator = ServerInstantiator(codegenContext)
            rustBlock("##[cfg(test)] fn build_valid_map() -> std::collections::HashMap<String, String>") {
                instantiator.render(this, constrainedMapShape, testCase.validMap)
            }
Loading