From 26ff8021b59b93857743bfd91136001317291d71 Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Fri, 18 Dec 2020 12:24:37 -0500 Subject: [PATCH] Refactor HttpProtocolTestGeneratorTest (#90) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As the ServiceGenerator grew more complex, it was becoming tedious to keep the hand-written operation in the HttpProtocolTestGenerator updated. This diff replaces the hand-written operation with a custom protocol. To support this, rudimentary support for pluggable protocols was added to CodegenVisitor—eventually this will be driven by SPI. --- .../rust/codegen/smithy/CodegenVisitor.kt | 5 +- .../smithy/protocols/ProtocolLoader.kt | 6 +- .../HttpProtocolTestGeneratorTest.kt | 192 +++++++----------- .../amazon/smithy/rust/testutil/Rust.kt | 30 +++ 4 files changed, 107 insertions(+), 126 deletions(-) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt index 9ae10d389..cf0a729b0 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt @@ -27,13 +27,14 @@ import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader +import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.runCommand import java.util.logging.Logger -class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { +class CodegenVisitor(context: PluginContext, extraProtocols: ProtocolMap = mapOf()) : ShapeVisitor.Default() { private val logger = Logger.getLogger(javaClass.name) private val settings = RustSettings.from(context.model, context.settings) @@ -51,7 +52,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig, codegenConfig = settings.codegenConfig) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) - val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service) + val (protocol, generator) = ProtocolLoader.withAdditional(extraProtocols).protocolFor(context.model, service) protocolGenerator = generator model = generator.transformModel(baseModel) val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt index 2fef20cc3..2d7be5998 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/ProtocolLoader.kt @@ -17,8 +17,11 @@ import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory +typealias ProtocolMap = Map> + +// typealias ProtocolMap = // TODO: supportedProtocols must be runtime loadable via SPI; 2d -class ProtocolLoader(private val supportedProtocols: Map>) { +class ProtocolLoader(private val supportedProtocols: ProtocolMap) { fun protocolFor( model: Model, serviceShape: ServiceShape @@ -39,5 +42,6 @@ class ProtocolLoader(private val supportedProtocols: Map #T>", RuntimeType.Http("request::Request") - ) { - write("#T::assemble(self.input.request_builder_base(), self.input.build_body())", inputSymbol) + // A stubbed test protocol to do enable testing intentionally broken protocols + class TestProtocol(protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) { + override fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) { + fromResponseFun(implBlockWriter, operationShape) { + writeWithNoFormatting(correctResponse) } + } - rustBlock("pub fn new(input: #T) -> Self", inputSymbol) { - write("Self { input }") + override fun toBodyImpl( + implBlockWriter: RustWriter, + inputShape: StructureShape, + inputBody: StructureShape? + ) { + bodyBuilderFun(implBlockWriter) { + writeWithNoFormatting(body) } + } - rustBlock( - "pub fn from_response(_response: &#T>) -> Result<#T, #T>", - - RuntimeType.Http("response::Response"), - symbolProvider.toSymbol(operationShape.outputShape(model)), - operationShape.errorSymbol(symbolProvider) - ) { - writeWithNoFormatting(correctResponse) + override fun toHttpRequestImpl( + implBlockWriter: RustWriter, + operationShape: OperationShape, + inputShape: StructureShape + ) { + httpBuilderFun(implBlockWriter) { + write("#T::new()", RuntimeType.HttpRequestBuilder) + writeWithNoFormatting(httpRequestBuilder) } } } - writer.withModule("output") { - val outputShape = operationShape.outputShape(model) - outputShape.renderWithModelBuilder(model, symbolProvider, this) - } - writer.withModule("input") { - StructureGenerator(model, symbolProvider, this, shape).render() - builderGenerator.render(this) - rustBlock("impl SayHelloInput") { - builderGenerator.renderConvenienceMethod(this) - rustBlock("pub fn request_builder_base(&self) -> #T", RuntimeType.HttpRequestBuilder) { - write("#T::new()", RuntimeType.HttpRequestBuilder) - write(httpRequestBuilder) - } - rustBlock("pub fn build_body(&self) -> String") { - write(body) - } - rustBlock( - "pub fn assemble>(builder: #1T, body: T) -> #2T<#3T>", - RuntimeType.HttpRequestBuilder, - RuntimeType.Http("request::Request"), - RuntimeType.ByteSlab - ) { - write("let body = body.into();") - write("builder.header(#T, body.len()).body(body)", RuntimeType.Http("header::CONTENT_LENGTH")) - write(""".expect("http request should be valid")""") - } + + class TestProtocolFactory : ProtocolGeneratorFactory { + override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpProtocolGenerator { + return TestProtocol(protocolConfig) + } + + override fun transformModel(model: Model): Model { + return OperationNormalizer(model).transformModel( + inputBodyFactory = OperationNormalizer.NoBody, + outputBodyFactory = OperationNormalizer.NoBody + ) + } + + override fun support(): ProtocolSupport { + return ProtocolSupport(true, true, true) } - val protocolConfig = ProtocolConfig( - model, - symbolProvider, - runtimeConfig, - model.lookup("com.example#HelloService"), - RestJson1Trait.ID - ) - HttpProtocolTestGenerator( - protocolConfig, - ProtocolSupport( - requestBodySerialization = true, - responseDeserialization = true, - errorDeserialization = false - ), - model.lookup("com.example#SayHello"), - this - ).render() } + + val (pluginContext, testDir) = generatePluginContext(model) + // Intentionally shadow the builtin implementation of RestJson1 with our fake protocol + val visitor = CodegenVisitor(pluginContext, mapOf(RestJson1Trait.ID to TestProtocolFactory())) + visitor.execute() + return testDir } @Test fun `passing e2e protocol request test`() { - val writer = RustWriter.root() - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=Hello%20there&required") .header("X-Greeting", "Hi") @@ -212,16 +175,14 @@ class HttpProtocolTestGeneratorTest { """ ) - val testOutput = writer.compileAndTest() + val testOutput = "cargo test".runCommand(path) // Verify the test actually ran testOutput shouldContain "say_hello_request ... ok" } @Test fun `test incorrect response parsing`() { - val writer = RustWriter.root() - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=Hello%20there&required") .header("X-Greeting", "Hi") @@ -230,7 +191,7 @@ class HttpProtocolTestGeneratorTest { correctResponse = "Ok(SayHelloOutput::builder().build())" ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } err.message shouldContain "basic_response_test_response ... FAILED" @@ -238,21 +199,17 @@ class HttpProtocolTestGeneratorTest { @Test fun `test invalid body`() { - val writer = RustWriter.root() - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=Hello%20there&required") .header("X-Greeting", "Hi") .method("POST") """, - """ - "{}".to_string() -""" + """"{}".to_string().into()""" ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } err.message shouldContain "say_hello_request ... FAILED" @@ -261,11 +218,9 @@ class HttpProtocolTestGeneratorTest { @Test fun `test invalid url parameter`() { - val writer = RustWriter.root() // Hard coded implementation for this 1 test - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=INCORRECT&required") .header("X-Greeting", "Hi") @@ -274,7 +229,7 @@ class HttpProtocolTestGeneratorTest { ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" @@ -283,11 +238,7 @@ class HttpProtocolTestGeneratorTest { @Test fun `test forbidden url parameter`() { - val writer = RustWriter.root() - - // Hard coded implementation for this 1 test - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?goodbye&Hi=Hello%20there&required") .header("X-Greeting", "Hi") @@ -296,7 +247,7 @@ class HttpProtocolTestGeneratorTest { ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" @@ -305,11 +256,8 @@ class HttpProtocolTestGeneratorTest { @Test fun `test required url parameter`() { - val writer = RustWriter.root() - // Hard coded implementation for this 1 test - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=Hello%20there") .header("X-Greeting", "Hi") @@ -318,7 +266,7 @@ class HttpProtocolTestGeneratorTest { ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } // Verify the test actually ran err.message shouldContain "say_hello_request ... FAILED" @@ -327,9 +275,7 @@ class HttpProtocolTestGeneratorTest { @Test fun `invalid header`() { - val writer = RustWriter.root() - writeHttpImpl( - writer, + val path = generateService( """ .uri("/?Hi=Hello%20there&required") // should be "Hi" @@ -339,7 +285,7 @@ class HttpProtocolTestGeneratorTest { ) val err = assertThrows { - writer.compileAndTest(expectFailure = true) + "cargo test".runCommand(path) } err.message shouldContain "say_hello_request ... FAILED" err.message shouldContain "invalid header value" diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt index af10ac84b..1a051a857 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt @@ -8,8 +8,11 @@ package software.amazon.smithy.rust.testutil import com.moandjiezana.toml.TomlWriter import org.intellij.lang.annotations.Language import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.lang.CargoDependency @@ -94,6 +97,33 @@ object TestWorkspace { } } +/** + * Generates a test plugin context for [model] and returns the plugin context and the path it is rooted it. + * + * Example: + * ```kotlin + * val (pluginContext, path) = generatePluginContext(model) + * CodegenVisitor(pluginContext).execute() + * "cargo test".runCommand(path) + * ``` + */ +fun generatePluginContext(model: Model): Pair { + val testDir = TestWorkspace.subproject() + val moduleName = "test_${testDir.nameWithoutExtension}" + val testPath = testDir.toPath() + val manifest = FileManifest.create(testPath) + val settings = Node.objectNodeBuilder() + .withMember("module", Node.from(moduleName)) + .withMember("moduleVersion", Node.from("1.0.0")) + .withMember( + "runtimeConfig", + Node.objectNodeBuilder().withMember("relativePath", Node.from(TestRuntimeConfig.relativePath)).build() + ) + .build() + val pluginContext = PluginContext.builder().model(model).fileManifest(manifest).settings(settings).build() + return pluginContext to testPath +} + fun RustWriter.unitTest( @Language("Rust", prefix = "fn test() {", suffix = "}") test: String, name: String? = null -- GitLab