Unverified Commit 2e7cb459 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Fix #756: Improve server codegen by leveraging RuntimeType (#774)



* Fix SdkBody test to just use assert!

* Fixes #756: Improve server codegen by leveraging RuntimeType. Cleanup
server CodegenVisitor to only render protocol specific code for
OperationShapes.

* Fix silly error I introduced in the smithy-http SdkBody tests

* Apply suggestions from code review

Co-authored-by: default avatardavid-perez <d@vidp.dev>

* Extract Json ser/de into an interface. Refactor the CodegenVisitor
structureShape to generate both inputs and outputs outside of the
ProtocolGenerator.

* Gather error members from the HTTP index

* Use operationShape as input for server generation of input/output/error

Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent 8b5eb685
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -82,7 +82,7 @@ task("generateCargoWorkspace") {
}

tasks["smithyBuildJar"].dependsOn("generateSmithyBuild")

tasks["assemble"].dependsOn("smithyBuildJar")
tasks["assemble"].finalizedBy("generateCargoWorkspace")

tasks.register<Exec>("cargoCheck") {
@@ -112,7 +112,7 @@ tasks.register<Exec>("cargoDocs") {
tasks.register<Exec>("cargoClippy") {
    workingDir("build/smithyprojections/codegen-server-test/")
    // disallow warnings
    commandLine("cargo", "clippy")
    commandLine("cargo", "clippy", "--", "-D", "warnings")
    dependsOn("assemble")
}

+8 −8
Original line number Diff line number Diff line
@@ -41,20 +41,20 @@ resource Service {
@http(method: "PUT", uri: "/service/{id}")
@documentation("Service register operation")
operation RegisterService {
    input: RegisterServiceInput,
    output: RegisterServiceOutput,
    input: RegisterServiceInputRequest,
    output: RegisterServiceOutputResponse,
    errors: [ResourceAlreadyExists]
}

@documentation("Service register input structure")
structure RegisterServiceInput {
structure RegisterServiceInputRequest {
    @required
    @httpLabel
    id: ServiceId,
}

@documentation("Service register output structure")
structure RegisterServiceOutput {
structure RegisterServiceOutputResponse {
    @required
    id: ServiceId
}
@@ -63,16 +63,16 @@ structure RegisterServiceOutput {
@http(uri: "/healthcheck", method: "GET")
@documentation("Read-only healthcheck operation")
operation Healthcheck {
    input: HealthcheckInput,
    output: HealthcheckOutput
    input: HealthcheckInputRequest,
    output: HealthcheckOutputResponse
}

@documentation("Service healthcheck output structure")
structure HealthcheckInput {
structure HealthcheckInputRequest {

}

@documentation("Service healthcheck input structure")
structure HealthcheckOutput {
structure HealthcheckOutputResponse {

}
+63 −39
Original line number Diff line number Diff line
@@ -9,7 +9,6 @@ import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
@@ -41,13 +40,10 @@ import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
@@ -57,11 +53,13 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStream
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger

/**
 * Entrypoint for server-side code generation. This class will walk the in-memory model and
 * generate all the needed types by calling the accept() function on the available shapes.
 */
class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustCodegenDecorator) :
    ShapeVisitor.Default<Unit>() {

@@ -76,12 +74,9 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
    private val protocolGeneratorFactory: ProtocolGeneratorFactory<ProtocolGenerator>
    private val protocolGenerator: ProtocolGenerator

    private val serializerGenerator: JsonSerializerGenerator
    private val deserializerGenerator: JsonParserGenerator
    private val httpSerializerGenerator: ServerGenerator
    private val httpDeserializerGenerator: ServerGenerator
    private val httpBindingResolver: HttpBindingResolver
    private val renderedStructures = mutableSetOf<StructureShape>()

    init {
        val symbolVisitorConfig =
@@ -116,8 +111,6 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
                ProtocolContentTypes.consistent("application/json"),
            )

        serializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver)
        deserializerGenerator = JsonParserGenerator(codegenContext, httpBindingResolver)
        when (codegenContext.protocol) {
            RestJson1Trait.ID -> {
                httpSerializerGenerator =
@@ -126,24 +119,43 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
                    RestJson1HttpDeserializerGenerator(codegenContext, httpBindingResolver)
            }
            else -> {
                // TODO: support other protocols
                throw Exception("Protocol ${codegenContext.protocol} not support yet")
                TODO("Protocol ${codegenContext.protocol} not supported yet")
            }
        }
    }

    /**
     * Base model transformation applied to all services.
     * See below for details.
     */
    private fun baselineTransform(model: Model) =
        model
            // Add `Box<T>` to recursive shapes as necessary
            .let(RecursiveShapeBoxer::transform)
            // Normalize the `message` field on errors when enabled in settings (default: true)
            .letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
            // Normalize operations by adding synthetic input and output shapes to every operation
            .let(OperationNormalizer::transform)
            // Drop unsupported event stream operations from the model
            .let { RemoveEventStreamOperations.transform(it, settings) }
            // Normalize event stream operations
            .let(EventStreamNormalizer::transform)

    /**
     * Execute code generation
     *
     * 1. Load the service from `RustSettings`.
     * 2. Traverse every shape in the closure of the service.
     * 3. Loop through each shape and visit them (calling the override functions in this class)
     * 4. Call finalization tasks specified by decorators.
     * 5. Write the in-memory buffers out to files.
     *
     * The main work of code generation (serializers, protocols, etc.) is handled in `fn serviceShape` below.
     */
    fun execute() {
        val service = settings.getService(model)
        logger.info(
            "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}..."
            "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}"
        )
        val serviceShapes = Walker(model).walkShapes(service)
        serviceShapes.forEach { it.accept(this) }
@@ -166,48 +178,39 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
                "[rust-server-codegen] Failed to run cargo fmt: [${service.id}]\n${err.output}"
            )
        }

        logger.info("[rust-server-codegen] Rust server generation complete!")
    }

    private fun renderStructure(
        writer: RustWriter,
        structureShape: StructureShape,
        includedMembers: List<MemberShape>,
    ) {
        // TODO: review this deduplication mechanism as it doesn't feel very ergonomic
        if (renderedStructures.add(structureShape)) {
            serializerGenerator.renderStructure(writer, structureShape, includedMembers)
            deserializerGenerator.renderStructure(writer, structureShape, includedMembers)
        }
    }

    override fun getDefault(shape: Shape?) {}

    /**
     * Operation Shape Visitor
     *
     * For each operation shape, generate the corresponding protocol implementation.
     */
    override fun operationShape(shape: OperationShape?) {
        logger.info("[rust-server-codegen] Generating operation $shape...")
        logger.info("[rust-server-codegen] Generating operation $shape")
        if (shape != null) {
            val inputHttpDocumentMembers =
                httpBindingResolver.requestMembers(shape, HttpLocation.DOCUMENT)
            val outputHttpDocumentMembers =
                httpBindingResolver.responseMembers(shape, HttpLocation.DOCUMENT)
            rustCrate.useShapeWriter(shape) { writer ->
                shape.let {
                    httpDeserializerGenerator.render(writer, it)
                    httpSerializerGenerator.render(writer, it)
                    renderStructure(writer, shape.inputShape(model), inputHttpDocumentMembers)
                    renderStructure(writer, shape.outputShape(model), outputHttpDocumentMembers)
                    shape.errors.forEach { error ->
                        val errorShape = model.expectShape(error, StructureShape::class.java)
                        renderStructure(writer, errorShape, errorShape.members().toList())
                    }
                }
            }
        }
    }

    /**
     * Structure Shape Visitor
     *
     * For each structure shape, generate:
     * - A Rust structure for the shape (`StructureGenerator`).
     * - A builder for the shape.
     *
     * This function _does not_ generate any serializers.
     */
    override fun structureShape(shape: StructureShape) {
        logger.info("[rust-server-codegen] Generating a structure $shape...")
        logger.info("[rust-server-codegen] Generating a structure $shape")
        rustCrate.useShapeWriter(shape) { writer ->
            StructureGenerator(model, symbolProvider, writer, shape).render()
            if (!shape.hasTrait<SyntheticInputTrait>()) {
@@ -221,6 +224,11 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
        }
    }

    /**
     * String Shape Visitor
     *
     * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes.
     */
    override fun stringShape(shape: StringShape) {
        shape.getTrait<EnumTrait>()?.also { enum ->
            rustCrate.useShapeWriter(shape) { writer ->
@@ -229,12 +237,28 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
        }
    }

    /**
     * Union Shape Visitor
     *
     * Generate an `enum` for union shapes.
     *
     * This function _does not_ generate any serializers.
     */
    override fun unionShape(shape: UnionShape) {
        rustCrate.useShapeWriter(shape) {
            UnionGenerator(model, symbolProvider, it, shape).render()
        }
    }

    /**
     * Generate service-specific code for the model:
     * - Serializers
     * - Deserializers
     * - Fluent client
     * - Trait implementations
     * - Protocol tests
     * - Operation structures
     */
    override fun serviceShape(shape: ServiceShape) {
        ServiceGenerator(
            rustCrate,
+59 −38
Original line number Diff line number Diff line
@@ -44,8 +44,8 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.protocols.serializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
@@ -58,6 +58,10 @@ import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger

/**
* TODO: this abstract class is here temporarily as it promotes a bit of code reusal between the RestJson1
* serialzer and deserializer. This will be refactored and probably removed soon.
*/
abstract class ServerGenerator(
    codegenContext: CodegenContext,
    private val httpBindingResolver: HttpTraitHttpBindingResolver,
@@ -74,6 +78,8 @@ abstract class ServerGenerator(
    public val index = HttpBindingIndex.of(model)
    public val service = codegenContext.serviceShape
    public val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS
    public val jsonSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver)
    public val jsonParserGenerator = JsonParserGenerator(codegenContext, httpBindingResolver)

    abstract fun render(writer: RustWriter, operationShape: OperationShape)
}
@@ -108,7 +114,14 @@ class RestJson1HttpSerializerGenerator(
        val outputShape = operationShape.outputShape(model)
        if (outputShape.hasStreamingMember(model)) {
            logger.warning(
                "$operationShape: response serialization does not currently support streaming shapes"
                "[rust-server-codegen] $operationShape: response serialization does not currently support streaming shapes"
            )
            return
        }
        val serializerSymbol = jsonSerializerGenerator.serverOutputSerializer(operationShape)
        if (serializerSymbol == null) {
            logger.warning(
                "[rust-server-codegen] $outputShape: response output serialization does not contain any member"
            )
            return
        }
@@ -119,8 +132,6 @@ class RestJson1HttpSerializerGenerator(
            *codegenScope,
            "O" to outputSymbol,
        ) {
            val serializerSymbol =
                operation.member(symbolProvider.serializeFunctionName(outputShape))
            rust(
                "let payload = #T(output)?;",
                serializerSymbol,
@@ -143,7 +154,7 @@ class RestJson1HttpSerializerGenerator(
                    }
                    HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.PAYLOAD -> {
                        logger.warning(
                            "$operationShape: response serialization does not currently support $location bindings"
                            "[rust-server-codegen] $operationShape: response serialization does not currently support $location bindings"
                        )
                    }
                    else -> {}
@@ -177,12 +188,12 @@ class RestJson1HttpSerializerGenerator(
                    val errorTrait = variantShape.expectTrait<ErrorTrait>()
                    val variantSymbol = symbolProvider.toSymbol(variantShape)
                    val data = safeName("var")
                    val serializerSymbol =
                        serde.member(symbolProvider.serializeFunctionName(variantShape))
                    val serializerSymbol = jsonSerializerGenerator.serverErrorSerializer(it)
                    if (serializerSymbol != null) {
                        rustBlock("#TKind::${variantSymbol.name}($data) =>", errorSymbol) {
                            rust(
                                """
                                #T(&mut object, &$data);
                                    #T(&$data)?;
                                    object.key(${"code".dq()}).string(${httpBindingResolver.errorCode(variantShape).dq()});
                                """.trimIndent(),
                                serializerSymbol
@@ -212,6 +223,11 @@ class RestJson1HttpSerializerGenerator(
                                    ?: errorTrait.defaultHttpStatusCode
                            rust("response = response.status($status);")
                        }
                    } else {
                        logger.warning(
                            "[rust-server-codegen] $variantShape: response error serialization does not contain any member"
                        )
                    }
                }
                rust(
                    """
@@ -387,7 +403,14 @@ class RestJson1HttpRequestDeserializerGenerator(
        val inputShape = operationShape.inputShape(model)
        if (inputShape.hasStreamingMember(model)) {
            logger.warning(
                "$operationShape: request deserialization does not currently support streaming shapes"
                "[rust-server-codegen] $operationShape: request deserialization does not currently support streaming shapes"
            )
            return
        }
        val deserializerSymbol = jsonParserGenerator.serverInputParser(operationShape)
        if (deserializerSymbol == null) {
            logger.warning(
                "[rust-server-codegen] $inputShape: response output serialization does not contain any member"
            )
            return
        }
@@ -398,8 +421,6 @@ class RestJson1HttpRequestDeserializerGenerator(
            *codegenScope,
            "I" to inputSymbol,
        ) {
            val deserializerSymbol =
                operation.member(symbolProvider.deserializeFunctionName(inputShape))
            rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
            rust(
                "input = #T(request.body().as_ref(), input)?;",
@@ -428,7 +449,7 @@ class RestJson1HttpRequestDeserializerGenerator(
            }
            else -> {
                logger.warning(
                    "$operationShape: request deserialization does not currently support $location bindings"
                    "[rust-server-codegen] $operationShape: request deserialization does not currently support $location bindings"
                )
                null
            }
+0 −1
Original line number Diff line number Diff line
@@ -178,7 +178,6 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
     * String Shape Visitor
     *
     * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes.
     * For strings that have the enum trait attached,
     */
    override fun stringShape(shape: StringShape) {
        shape.getTrait<EnumTrait>()?.also { enum ->
Loading