Unverified Commit 26ff8021 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Refactor HttpProtocolTestGeneratorTest (#90)

As the ServiceGenerator grew more complex, it was becoming tedious to keep the hand-written operation in the HttpProtocolTestGenerator updated.

This diff replaces the hand-written operation with a custom protocol. To support this, rudimentary support for pluggable protocols was added to CodegenVisitor—eventually this will be driven by SPI.
parent c0aa43c6
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -27,13 +27,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger

class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
class CodegenVisitor(context: PluginContext, extraProtocols: ProtocolMap = mapOf()) : ShapeVisitor.Default<Unit>() {

    private val logger = Logger.getLogger(javaClass.name)
    private val settings = RustSettings.from(context.model, context.settings)
@@ -51,7 +52,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
            SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig, codegenConfig = settings.codegenConfig)
        val baseModel = baselineTransform(context.model)
        val service = settings.getService(baseModel)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        val (protocol, generator) = ProtocolLoader.withAdditional(extraProtocols).protocolFor(context.model, service)
        protocolGenerator = generator
        model = generator.transformModel(baseModel)
        val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig)
+5 −1
Original line number Diff line number Diff line
@@ -17,8 +17,11 @@ import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory

typealias ProtocolMap = Map<ShapeId, ProtocolGeneratorFactory<HttpProtocolGenerator>>

// typealias ProtocolMap =
// TODO: supportedProtocols must be runtime loadable via SPI; 2d
class ProtocolLoader(private val supportedProtocols: Map<ShapeId, ProtocolGeneratorFactory<HttpProtocolGenerator>>) {
class ProtocolLoader(private val supportedProtocols: ProtocolMap) {
    fun protocolFor(
        model: Model,
        serviceShape: ServiceShape
@@ -39,5 +42,6 @@ class ProtocolLoader(private val supportedProtocols: Map<ShapeId, ProtocolGenera
            RestJson1Trait.ID to AwsRestJsonFactory()
        )
        val Default = ProtocolLoader(Protocols)
        fun withAdditional(protocols: ProtocolMap) = ProtocolLoader(Protocols + protocols)
    }
}
+69 −123
Original line number Diff line number Diff line
@@ -9,24 +9,22 @@ import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.CodegenVisitor
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.util.runCommand
import software.amazon.smithy.rust.testutil.asSmithyModel
import software.amazon.smithy.rust.testutil.compileAndTest
import software.amazon.smithy.rust.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.testutil.testSymbolProvider
import software.amazon.smithy.rust.testutil.generatePluginContext
import java.nio.file.Path

class HttpProtocolTestGeneratorTest {
    private val baseModel = """
    private val model = """
        namespace com.example

        use aws.protocols#restJson1
@@ -100,111 +98,76 @@ class HttpProtocolTestGeneratorTest {
            name: String
        }
    """.asSmithyModel()
    private val model = OperationNormalizer(baseModel).transformModel(
        inputBodyFactory = OperationNormalizer.NoBody,
        outputBodyFactory = OperationNormalizer.NoBody
    )
    private val symbolProvider = testSymbolProvider(model)
    private val runtimeConfig = TestRuntimeConfig
    private val correctBody = """{"name": "Teddy"}"""

    /**
     * Creates an fake HTTP implementation for SayHello & generates the protocol test
     *
     * Returns the [Path] the service was generated at, suitable for running `cargo test`
     */
    private fun writeHttpImpl(
        writer: RustWriter,
    private fun generateService(
        httpRequestBuilder: String,
        body: String = "${correctBody.dq()}.to_string()",
        body: String = "${correctBody.dq()}.to_string().into()",
        correctResponse: String = """Ok(SayHelloOutput::builder().value("hey there!").build())"""
    ) {
        val shape: StructureShape = model.lookup("com.example#SayHelloInput")
        val inputSymbol = symbolProvider.toSymbol(shape)
        val operationShape: OperationShape = model.lookup("com.example#SayHello")
        val builderGenerator = OperationInputBuilderGenerator(model, symbolProvider, model.lookup("com.example#SayHello"))
        writer.withModule("error") {
            StructureGenerator(model, symbolProvider, this, model.lookup("com.example#BadRequest")).render()
            CombinedErrorGenerator(model, symbolProvider, operationShape).render(this)
        }
        writer.withModule("operation") {
            rustBlock("pub struct SayHello") {
                write("input: #T", inputSymbol)
            }
            implBlock(operationShape, symbolProvider) {
                builderGenerator.renderConvenienceMethod(this)

                rustBlock(
                    "pub fn build_http_request(&self) -> #T<Vec<u8>>", RuntimeType.Http("request::Request")
                ) {
                    write("#T::assemble(self.input.request_builder_base(), self.input.build_body())", inputSymbol)
                }

                rustBlock("pub fn new(input: #T) -> Self", inputSymbol) {
                    write("Self { input }")
                }

                rustBlock(
                    "pub fn from_response(_response: &#T<impl AsRef<[u8]>>) -> Result<#T, #T>",
    ): Path {

                    RuntimeType.Http("response::Response"),
                    symbolProvider.toSymbol(operationShape.outputShape(model)),
                    operationShape.errorSymbol(symbolProvider)
                ) {
        // A stubbed test protocol to do enable testing intentionally broken protocols
        class TestProtocol(protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) {
            override fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) {
                fromResponseFun(implBlockWriter, operationShape) {
                    writeWithNoFormatting(correctResponse)
                }
            }

            override fun toBodyImpl(
                implBlockWriter: RustWriter,
                inputShape: StructureShape,
                inputBody: StructureShape?
            ) {
                bodyBuilderFun(implBlockWriter) {
                    writeWithNoFormatting(body)
                }
        writer.withModule("output") {
            val outputShape = operationShape.outputShape(model)
            outputShape.renderWithModelBuilder(model, symbolProvider, this)
            }
        writer.withModule("input") {
            StructureGenerator(model, symbolProvider, this, shape).render()
            builderGenerator.render(this)
            rustBlock("impl SayHelloInput") {
                builderGenerator.renderConvenienceMethod(this)
                rustBlock("pub fn request_builder_base(&self) -> #T", RuntimeType.HttpRequestBuilder) {

            override fun toHttpRequestImpl(
                implBlockWriter: RustWriter,
                operationShape: OperationShape,
                inputShape: StructureShape
            ) {
                httpBuilderFun(implBlockWriter) {
                    write("#T::new()", RuntimeType.HttpRequestBuilder)
                    write(httpRequestBuilder)
                    writeWithNoFormatting(httpRequestBuilder)
                }
                rustBlock("pub fn build_body(&self) -> String") {
                    write(body)
            }
                rustBlock(
                    "pub fn assemble<T: Into<#3T>>(builder: #1T, body: T) -> #2T<#3T>",
                    RuntimeType.HttpRequestBuilder,
                    RuntimeType.Http("request::Request"),
                    RuntimeType.ByteSlab
                ) {
                    write("let body = body.into();")
                    write("builder.header(#T, body.len()).body(body)", RuntimeType.Http("header::CONTENT_LENGTH"))
                    write(""".expect("http request should be valid")""")
        }

        class TestProtocolFactory : ProtocolGeneratorFactory<HttpProtocolGenerator> {
            override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpProtocolGenerator {
                return TestProtocol(protocolConfig)
            }
            val protocolConfig = ProtocolConfig(
                model,
                symbolProvider,
                runtimeConfig,
                model.lookup("com.example#HelloService"),
                RestJson1Trait.ID

            override fun transformModel(model: Model): Model {
                return OperationNormalizer(model).transformModel(
                    inputBodyFactory = OperationNormalizer.NoBody,
                    outputBodyFactory = OperationNormalizer.NoBody
                )
            HttpProtocolTestGenerator(
                protocolConfig,
                ProtocolSupport(
                    requestBodySerialization = true,
                    responseDeserialization = true,
                    errorDeserialization = false
                ),
                model.lookup("com.example#SayHello"),
                this
            ).render()
            }

            override fun support(): ProtocolSupport {
                return ProtocolSupport(true, true, true)
            }
        }

        val (pluginContext, testDir) = generatePluginContext(model)
        // Intentionally shadow the builtin implementation of RestJson1 with our fake protocol
        val visitor = CodegenVisitor(pluginContext, mapOf(RestJson1Trait.ID to TestProtocolFactory()))
        visitor.execute()
        return testDir
    }

    @Test
    fun `passing e2e protocol request test`() {
        val writer = RustWriter.root()
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
@@ -212,16 +175,14 @@ class HttpProtocolTestGeneratorTest {
                """
        )

        val testOutput = writer.compileAndTest()
        val testOutput = "cargo test".runCommand(path)
        // Verify the test actually ran
        testOutput shouldContain "say_hello_request ... ok"
    }

    @Test
    fun `test incorrect response parsing`() {
        val writer = RustWriter.root()
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
@@ -230,7 +191,7 @@ class HttpProtocolTestGeneratorTest {
            correctResponse = "Ok(SayHelloOutput::builder().build())"
        )
        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }

        err.message shouldContain "basic_response_test_response ... FAILED"
@@ -238,21 +199,17 @@ class HttpProtocolTestGeneratorTest {

    @Test
    fun `test invalid body`() {
        val writer = RustWriter.root()
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
                    .method("POST")
                """,
            """
                    "{}".to_string()
"""
            """"{}".to_string().into()"""
        )

        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }

        err.message shouldContain "say_hello_request ... FAILED"
@@ -261,11 +218,9 @@ class HttpProtocolTestGeneratorTest {

    @Test
    fun `test invalid url parameter`() {
        val writer = RustWriter.root()

        // Hard coded implementation for this 1 test
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=INCORRECT&required")
                    .header("X-Greeting", "Hi")
@@ -274,7 +229,7 @@ class HttpProtocolTestGeneratorTest {
        )

        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }
        // Verify the test actually ran
        err.message shouldContain "say_hello_request ... FAILED"
@@ -283,11 +238,7 @@ class HttpProtocolTestGeneratorTest {

    @Test
    fun `test forbidden url parameter`() {
        val writer = RustWriter.root()

        // Hard coded implementation for this 1 test
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?goodbye&Hi=Hello%20there&required")
                    .header("X-Greeting", "Hi")
@@ -296,7 +247,7 @@ class HttpProtocolTestGeneratorTest {
        )

        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }
        // Verify the test actually ran
        err.message shouldContain "say_hello_request ... FAILED"
@@ -305,11 +256,8 @@ class HttpProtocolTestGeneratorTest {

    @Test
    fun `test required url parameter`() {
        val writer = RustWriter.root()

        // Hard coded implementation for this 1 test
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=Hello%20there")
                    .header("X-Greeting", "Hi")
@@ -318,7 +266,7 @@ class HttpProtocolTestGeneratorTest {
        )

        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }
        // Verify the test actually ran
        err.message shouldContain "say_hello_request ... FAILED"
@@ -327,9 +275,7 @@ class HttpProtocolTestGeneratorTest {

    @Test
    fun `invalid header`() {
        val writer = RustWriter.root()
        writeHttpImpl(
            writer,
        val path = generateService(
            """
                    .uri("/?Hi=Hello%20there&required")
                    // should be "Hi"
@@ -339,7 +285,7 @@ class HttpProtocolTestGeneratorTest {
        )

        val err = assertThrows<CommandFailed> {
            writer.compileAndTest(expectFailure = true)
            "cargo test".runCommand(path)
        }
        err.message shouldContain "say_hello_request ... FAILED"
        err.message shouldContain "invalid header value"
+30 −0
Original line number Diff line number Diff line
@@ -8,8 +8,11 @@ package software.amazon.smithy.rust.testutil
import com.moandjiezana.toml.TomlWriter
import org.intellij.lang.annotations.Language
import software.amazon.smithy.build.FileManifest
import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.lang.CargoDependency
@@ -94,6 +97,33 @@ object TestWorkspace {
    }
}

/**
 * Generates a test plugin context for [model] and returns the plugin context and the path it is rooted it.
 *
 * Example:
 * ```kotlin
 * val (pluginContext, path) = generatePluginContext(model)
 * CodegenVisitor(pluginContext).execute()
 * "cargo test".runCommand(path)
 * ```
 */
fun generatePluginContext(model: Model): Pair<PluginContext, Path> {
    val testDir = TestWorkspace.subproject()
    val moduleName = "test_${testDir.nameWithoutExtension}"
    val testPath = testDir.toPath()
    val manifest = FileManifest.create(testPath)
    val settings = Node.objectNodeBuilder()
        .withMember("module", Node.from(moduleName))
        .withMember("moduleVersion", Node.from("1.0.0"))
        .withMember(
            "runtimeConfig",
            Node.objectNodeBuilder().withMember("relativePath", Node.from(TestRuntimeConfig.relativePath)).build()
        )
        .build()
    val pluginContext = PluginContext.builder().model(model).fileManifest(manifest).settings(settings).build()
    return pluginContext to testPath
}

fun RustWriter.unitTest(
    @Language("Rust", prefix = "fn test() {", suffix = "}") test: String,
    name: String? = null