Unverified Commit 9028c847 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Extract protocol test logic and improve tests in server generation (#798)

* Extract and improve protocol tests generation in their own class
parent edcc7c19
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
@@ -273,7 +274,12 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
        ServiceGenerator(
            rustCrate,
            protocolGenerator,
            protocolGeneratorFactory.support(),
            ProtocolSupport(
                requestDeserialization = true,
                requestBodyDeserialization = true,
                responseSerialization = true,
                errorSerialization = true
            ),
            codegenContext,
            codegenDecorator
        )
@@ -296,7 +302,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

                impl Error {
                    ##[allow(dead_code)]
                    fn generic(msg: &'static str) -> Self {
                    pub fn generic(msg: &'static str) -> Self {
                        Self::Generic(msg.into())
                    }
                }
+24 −0
Original line number Diff line number Diff line
@@ -22,6 +22,11 @@ import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecor
import java.util.logging.Level
import java.util.logging.Logger

/** Rust Codegen Plugin
 *  This is the entrypoint for code generation, triggered by the smithy-build plugin.
 *  `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which
 *  enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models.
 */
class RustCodegenPlugin : SmithyBuildPlugin {
    private val logger = Logger.getLogger(javaClass.name)

@@ -30,23 +35,42 @@ class RustCodegenPlugin : SmithyBuildPlugin {
    override fun execute(context: PluginContext) {
        // Suppress extremely noisy logs about reserved words
        Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF
        // Discover `RustCodegenDecorators` on the classpath. `RustCodegenDectorator` return different types of
        // customization. A customization is a function of:
        // - location (eg. the mutate section of an operation)
        // - context (eg. the of the operation)
        // - writer: The active RustWriter at the given location
        val codegenDecorator = CombinedCodegenDecorator.fromClasspath(context)

        // CodegenVistor is the main driver of code generation that traverses the model and generates code
        CodegenVisitor(context, codegenDecorator).execute()
    }

    companion object {
        /** SymbolProvider
         * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider
         *
         * The Symbol provider is composed of a base `SymbolVisitor` which handles the core funcitonality, then is layered
         * with other symbol providers, documented inline, to handle the full scope of Smithy types.
         */
        fun baseSymbolProvider(
            model: Model,
            serviceShape: ServiceShape,
            symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig
        ) =
            SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
                // Generate different types for EventStream shapes (eg. transcribe streaming)
                .let {
                    EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model)
                }
                // Generate `ByteStream` instead of `Blob` for streaming binary shapes (eg. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it) }
                // Streaming shapes need different derives (eg. they cannot derive Eq)
                .let { StreamingShapeMetadataProvider(it, model) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features eg. `send()` cannot
                // be the name of an operation input
                .let { RustReservedWordSymbolProvider(it) }
    }
}
+6 −1
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
@@ -14,7 +16,6 @@ import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfi
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.util.inputShape

class ServiceGenerator(
@@ -37,6 +38,10 @@ class ServiceGenerator(
                        operation,
                        decorator.operationCustomizations(context, operation, listOf())
                    )
                    // render protocol tests into `operation.rs` (note operationWriter vs.
                    // inputWriter)
                    ProtocolTestGenerator(context, protocolSupport, operation, operationWriter)
                        .render()
                }
            }
            rustCrate.withModule(RustModule.Error) { writer ->
+447 −0

File added.

Preview size limit exceeded, changes collapsed.

+2 −162
Original line number Diff line number Diff line
@@ -6,22 +6,14 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.IdempotencyTokenTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.protocoltests.traits.AppliesTo
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.CratesIo
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
@@ -47,13 +39,10 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
import software.amazon.smithy.rust.codegen.testutil.TokioTest
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findMemberWithTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@@ -90,7 +79,6 @@ class RestJson1HttpSerializerGenerator(
    fun render(writer: RustWriter, operationShape: OperationShape) {
        renderResponseSerializer(writer, operationShape)
        renderErrorSerializer(writer, operationShape)
        renderTests(writer, operationShape)
    }

    private fun renderResponseSerializer(writer: RustWriter, operationShape: OperationShape) {
@@ -121,7 +109,7 @@ class RestJson1HttpSerializerGenerator(
                    HttpLocation.RESPONSE_CODE -> {
                        rustTemplate(
                            """
                                let status = output.${it.memberName}
                                let status = output.${it.memberName.toLowerCase()}
                                    .ok_or(#{JsonSerdeError}::generic(${(it.member.memberName + " missing or empty").dq()}))?;
                                let http_status: u16 = #{Convert}::TryFrom::<i32>::try_from(status)
                                    .map_err(|_| #{JsonSerdeError}::generic(${("invalid status code").dq()}))?;
@@ -221,100 +209,6 @@ class RestJson1HttpSerializerGenerator(
            )
        }
    }

    private fun renderTests(writer: RustWriter, operationShape: OperationShape) {
        val operationIndex = OperationIndex.of(model)
        val outputShape = operationShape.outputShape(model)
        val responseTests =
            operationShape
                .getTrait<HttpResponseTestsTrait>()
                ?.getTestCasesFor(AppliesTo.SERVER)
                .orEmpty()
                .map { it to outputShape }
        val errorTests =
            operationIndex.getErrors(operationShape).flatMap { error ->
                error.getTrait<HttpResponseTestsTrait>()?.testCases.orEmpty().map {
                    it to error
                }
            }
        if (responseTests.isEmpty() && errorTests.isEmpty()) {
            return
        }
        val testModuleName = "serialize_${operationShape.id.name.toSnakeCase()}_test"
        val moduleMeta =
            RustMetadata(
                public = false,
                additionalAttributes =
                listOf(
                    Attribute.Cfg("test"),
                    Attribute.Custom(
                        "allow(unreachable_code, unused_variables)"
                    )
                )
            )
        writer.write("")
        writer.withModule(testModuleName, moduleMeta) {
            responseTests.forEach {
                renderSerializeResponseTestCase(operationShape, it.first, it.second)
            }
            errorTests.forEach {
                renderSerializeResponseTestCase(operationShape, it.first, it.second)
            }
        }
    }

    private fun RustWriter.renderSerializeResponseTestCase(
        operationShape: OperationShape,
        testCase: HttpResponseTestCase,
        shape: StructureShape
    ) {
        val isError = shape.hasTrait<ErrorTrait>()
        val fnName =
            if (isError) "serialize_${operationShape.id.name.toSnakeCase()}_error"
            else "serialize_${operationShape.id.name.toSnakeCase()}_response"
        val variantName =
            if (isError)
                "${format(operationShape.errorSymbol(symbolProvider))}Kind::${symbolProvider.toSymbol(shape).name}"
            else ""
        Attribute.Custom("test").render(this)
        rustBlock("fn ${testCase.id.toSnakeCase()}()") {
            rust("let config = #T::Config::builder().build();", RuntimeType.Config)
            writeInline("let expected = ")
            instantiator.render(this, shape, testCase.params)
            write(";")
            if (isError) {
                rust("let expected = $variantName(expected);")
            }
            rust(
                """let response = #T(&expected).expect("failed to serialize response");""",
                operation.member(fnName)
            )
            rust("assert_eq!(response.status(), ${testCase.code});")
            rustTemplate(
                """
                    use #{ParseHttpResponse};
                    let parser = #{op}::new();
                    let actual = parser.parse_loaded(&response);
                """.trimIndent(),
                *codegenScope,
                "op" to symbolProvider.toSymbol(operationShape),
            )
            if (isError) {
                rust("""let actual = actual.expect_err("failed to parse error");""")
                rust(
                    """
                        match (&expected, &actual.kind) {
                            ($variantName(expected), $variantName(actual)) => assert_eq!(expected, actual),
                            _ => panic!("incorrect error type"),
                        };
                    """.trimIndent()
                )
            } else {
                rust("""let actual = actual.expect("failed to parse error");""")
                rust("assert_eq!(expected, actual);")
            }
        }
    }
}

class RestJson1HttpDeserializerGenerator(
@@ -348,7 +242,7 @@ class RestJson1HttpDeserializerGenerator(

    fun render(writer: RustWriter, operationShape: OperationShape) {
        renderRequestDeserializer(writer, operationShape)
        renderRequestDeserializerTests(writer, operationShape)
        // renderRequestDeserializerTests(writer, operationShape)
    }

    private fun renderRequestDeserializer(writer: RustWriter, operationShape: OperationShape) {
@@ -562,58 +456,4 @@ class RestJson1HttpDeserializerGenerator(
        val memberName = binding.memberName.toSnakeCase()
        return "deser_label_${containerName}_$memberName"
    }

    private fun renderRequestDeserializerTests(writer: RustWriter, operationShape: OperationShape) {
        val testCases =
            operationShape.getTrait<HttpRequestTestsTrait>()?.getTestCasesFor(AppliesTo.SERVER)
                ?: return
        val testModuleName = "deser_${operationShape.id.name.toSnakeCase()}_test"
        val moduleMeta =
            RustMetadata(
                public = false,
                additionalAttributes =
                listOf(
                    Attribute.Cfg("test"),
                    Attribute.Custom(
                        "allow(unreachable_code, unused_variables)"
                    )
                )
            )
        writer.write("")
        writer.withModule(testModuleName, moduleMeta) {
            testCases.forEach { renderRequestDeserializerTestCase(it, operationShape) }
        }
    }

    private fun RustWriter.renderRequestDeserializerTestCase(testCase: HttpRequestTestCase, operationShape: OperationShape) {
        TokioTest.render(this)
        rustBlock("async fn ${testCase.id.toSnakeCase()}()") {
            val inputShape = operationShape.inputShape(model)
            val deserFnName = "deser_${operationShape.id.name.toSnakeCase()}_request"
            val customToken =
                inputShape.findMemberWithTrait<IdempotencyTokenTrait>(model)?.let {
                    """.make_token("00000000-0000-4000-8000-000000000000")"""
                }
                    ?: ""
            rust("let config = #T::Config::builder()$customToken.build();", RuntimeType.Config)
            writeInline("let expected = ")
            instantiator.render(this, inputShape, testCase.params)
            write(";")
            rust(
                """let op = expected.make_operation(&config).await.expect("failed to build operation");"""
            )
            rust("let (request, parts) = op.into_request_response().0.into_parts();")
            rustTemplate(
                "let request = request.map(|body| #{Bytes}::from(body.bytes().unwrap().to_vec()));",
                *codegenScope
            )
            rust("assert_eq!(request.method().as_str(), ${testCase.method.dq()});")
            rust("assert_eq!(request.uri().to_owned(), ${testCase.uri.dq()});")
            rust(
                """let actual = #T(&request).expect("failed to parse request");""",
                operation.member(deserFnName)
            )
            rust("assert_eq!(expected, actual);")
        }
    }
}