From 1d30ad06ca743395c41337b43ccf71f0dea6465f Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Wed, 4 Nov 2020 11:17:19 -0500 Subject: [PATCH] Add initial protocol test generation (#16) * Implement Instantiator * Add support for dev-dependencies * Pass the protocol in Protocol config * Add initial protocol test generator This commit adds an initial and very limited implementation of protocol test generation. Specifically, it only asserts that the query string contains certain required parameters. We will update protocol test generation in concert with updates to the code generation features. * Fixes from integration tests * Add test of recursive maps, fix implementation of empty maps * Add protocol test generator, pre-commit hooks * Fix test * Fix formatting --- codegen-test/build.gradle.kts | 10 +- .../rust/codegen/lang/RustDependency.kt | 11 +- .../rust/codegen/smithy/CodegenVisitor.kt | 7 +- .../rust/codegen/smithy/RuntimeTypes.kt | 10 + .../smithy/generators/CargoTomlGenerator.kt | 15 +- .../generators/HttpProtocolGenerator.kt | 4 +- .../generators/HttpProtocolTestGenerator.kt | 66 ++++++ .../codegen/smithy/generators/Instantiator.kt | 197 ++++++++++++++++++ .../smithy/generators/ServiceGenerator.kt | 9 +- .../smithy/generators/StructureGenerator.kt | 25 ++- .../amazon/smithy/rust/codegen/util/Exec.kt | 13 +- .../amazon/smithy/rust/codegen/util/Smithy.kt | 9 + .../HttpProtocolTestGeneratorTest.kt | 151 ++++++++++++++ .../smithy/generators/InstantiatorTest.kt | 165 +++++++++++++++ .../amazon/smithy/rust/testutil/Rust.kt | 42 ++-- .../smithy/rust/testutil/TestHelpers.kt | 2 +- rust-runtime/smithy-types/src/lib.rs | 6 + 17 files changed, 705 insertions(+), 37 deletions(-) create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt create mode 100644 codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt create mode 100644 codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt diff --git a/codegen-test/build.gradle.kts b/codegen-test/build.gradle.kts index 8e6a6a69e..7efd698c7 100644 --- a/codegen-test/build.gradle.kts +++ b/codegen-test/build.gradle.kts @@ -93,6 +93,14 @@ tasks.register("cargoCheck") { dependsOn("build") } +tasks.register("cargoTest") { + workingDir("build/smithyprojections/codegen-test/") + // disallow warnings + environment("RUSTFLAGS", "-D warnings") + commandLine("cargo", "test") + dependsOn("build") +} + tasks.register("cargoClippy") { workingDir("build/smithyprojections/codegen-test/") // disallow warnings @@ -101,7 +109,7 @@ tasks.register("cargoClippy") { dependsOn("build") } -tasks["test"].finalizedBy("cargoCheck", "cargoClippy") +tasks["test"].finalizedBy("cargoCheck", "cargoClippy", "cargoTest") tasks["clean"].doFirst { delete("smithy-build.json") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt index cce8c3b3a..7763ed679 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt @@ -9,13 +9,18 @@ import software.amazon.smithy.codegen.core.SymbolDependency import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +sealed class DependencyScope +object Dev : DependencyScope() +object Compile : DependencyScope() + sealed class DependencyLocation data class CratesIo(val version: String) : DependencyLocation() data class Local(val path: String? = null) : DependencyLocation() data class RustDependency( val name: String, - val location: DependencyLocation + val location: DependencyLocation, + val scope: DependencyScope = Compile ) : SymbolDependencyContainer { override fun getDependencies(): List { return listOf( @@ -44,6 +49,10 @@ data class RustDependency( "${runtimeConfig.cratePrefix}-http", Local(runtimeConfig.relativePath) ) + fun ProtocolTestHelpers(runtimeConfig: RuntimeConfig) = RustDependency( + "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = Dev + ) + private val PropKey = "rustdep" fun fromSymbolDependency(symbolDependency: SymbolDependency) = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt index 08b5b1376..86ae3566d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt @@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.runCommand private val PublicModules = listOf("error", "operation", "model") @@ -71,7 +72,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { LibRsGenerator(modules, it).render() } writers.flushWriters() - "cargo fmt".runCommand(fileManifest.baseDir) + try { + "cargo fmt".runCommand(fileManifest.baseDir) + } catch (_: CommandFailed) { + logger.warning("Generated output did not parse") + } } override fun getDefault(shape: Shape?) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index 8703add78..c3df3fd49 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -40,12 +40,14 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na // TODO: refactor to be RuntimeTypeProvider a la Symbol provider that packages the `RuntimeConfig` state. companion object { + // val Blob = RuntimeType("Blob", RustDependency.IO_CORE, "blob") val From = RuntimeType("From", dependency = null, namespace = "std::convert") val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert") fun StdFmt(member: String) = RuntimeType("fmt::$member", dependency = null, namespace = "std") val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") val HashSet = RuntimeType("HashSet", dependency = null, namespace = "std::collections") + val HashMap = RuntimeType("HashMap", dependency = null, namespace = "std::collections") fun Instant(runtimeConfig: RuntimeConfig) = RuntimeType("Instant", RustDependency.SmithyTypes(runtimeConfig), "${runtimeConfig.cratePrefix}_types") @@ -62,6 +64,9 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na fun Base64Encode(runtimeConfig: RuntimeConfig): RuntimeType = RuntimeType("encode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64") + fun Base64Decode(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType("decode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64") + fun TimestampFormat(runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format): RuntimeType { val timestampFormat = when (format) { TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" @@ -76,6 +81,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na ) } + fun ProtocolTestHelper(runtimeConfig: RuntimeConfig, func: String): RuntimeType = + RuntimeType( + func, RustDependency.ProtocolTestHelpers(runtimeConfig), "protocol_test_helpers" + ) + fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = RustDependency.Http, namespace = "http") val HttpRequestBuilder = Http("request::Builder") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt index 38d0ba041..bba2fa0d6 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.smithy.generators +import software.amazon.smithy.rust.codegen.lang.Compile +import software.amazon.smithy.rust.codegen.lang.Dev import software.amazon.smithy.rust.codegen.lang.RustDependency import software.amazon.smithy.rust.codegen.smithy.RustSettings import software.amazon.smithy.utils.CodeWriter @@ -20,9 +22,18 @@ class CargoTomlGenerator(private val settings: RustSettings, private val writer: writer.insertTrailingNewline() - if (dependencies.isNotEmpty()) { + val compileDependencies = dependencies.filter { it.scope == Compile } + val devDependencies = dependencies.filter { it.scope == Dev } + if (compileDependencies.isNotEmpty()) { writer.write("[dependencies]") - dependencies.forEach { + compileDependencies.forEach { + writer.write(it.toString()) + } + } + + if (devDependencies.isNotEmpty()) { + writer.write("[dev-dependencies]") + devDependencies.forEach { writer.write(it.toString()) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt index c10c79864..3a6da590a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rustBlock @@ -22,7 +23,8 @@ data class ProtocolConfig( val writer: RustWriter, val serviceShape: ServiceShape, val operationShape: OperationShape, - val inputShape: StructureShape + val inputShape: StructureShape, + val protocol: ShapeId ) interface ProtocolGeneratorFactory { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt new file mode 100644 index 000000000..6ddb1f64a --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -0,0 +1,66 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase +import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.util.dq + +class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { + fun render() { + with(protocolConfig) { + operationShape.getTrait(HttpRequestTestsTrait::class.java).map { + renderHttpRequestTests(it) + } + } + } + + private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) { + with(protocolConfig) { + writer.write("#[cfg(test)]") + val operationName = symbolProvider.toSymbol(operationShape).name + val testModuleName = "${operationName.toSnakeCase()}_request_test" + writer.withModule(testModuleName) { + httpRequestTestsTrait.testCases.filter { it.protocol == protocol }.forEach { testCase -> + renderHttpRequestTestCase(testCase, this) + } + } + } + } + + private val instantiator = with(protocolConfig) { + Instantiator(symbolProvider, model, runtimeConfig) + } + + private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) { + httpRequestTestCase.documentation.map { + testModuleWriter.setNewlinePrefix("/// ").write(it).setNewlinePrefix("") + } + testModuleWriter.write("#[test]") + testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { + writeInline("let input =") + instantiator.render(httpRequestTestCase.params, protocolConfig.inputShape, this) + write(";") + write("let http_request = input.build_http_request().body(()).unwrap();") + with(httpRequestTestCase) { + write( + """ + assert_eq!(http_request.method(), ${method.dq()}); + assert_eq!(http_request.uri().path(), ${uri.dq()}); + """ + ) + withBlock("let expected_query_params = vec![", "];") { + write(queryParams.joinToString(",") { it.dq() }) + } + write( + "\$T(&http_request, expected_query_params.as_slice()).unwrap();", + RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_query_string") + ) + // TODO: assert on the body contents + write("/* BODY:\n ${body.orElse("[ No Body ]")} */") + } + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt new file mode 100644 index 000000000..c541487ea --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -0,0 +1,197 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ArrayNode +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.NumberNode +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.node.StringNode +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.util.dq + +/** + * Instantiator generates code to instantiate a given Shape given a `Node` representing the value + * + * This is primarily used during Protocol test generation + */ +class Instantiator( + private val symbolProvider: SymbolProvider, + private val model: Model, + private val runtimeConfig: RuntimeConfig +) { + + fun render(arg: Node, shape: Shape, writer: RustWriter) { + when (shape) { + // Compound Shapes + is StructureShape -> renderStructure(writer, shape, arg as ObjectNode) + is UnionShape -> renderUnion(writer, shape, arg as ObjectNode) + + // Collections + is ListShape -> renderList(writer, shape, arg as ArrayNode) + is MapShape -> renderMap(writer, shape, arg as ObjectNode) + + // Wrapped Shapes + is TimestampShape -> writer.write( + "\$T::from_epoch_seconds(${(arg as NumberNode).value})", + RuntimeType.Instant(runtimeConfig) + ) + + is BlobShape -> writer.write( + "\$T::new(\$T(${(arg as StringNode).value.dq()}).unwrap())", + RuntimeType.Blob(runtimeConfig), + RuntimeType.Base64Decode(runtimeConfig) + ) + + // Simple Shapes + is StringShape -> renderString(writer, shape, arg as StringNode) + is NumberShape -> writer.write(arg.asNumberNode().get()) + is BooleanShape -> writer.write(arg.asBooleanNode().get().toString()) + else -> writer.write("todo!() /* $shape $arg */") + } + } + + /** + * ```rust + * { + * let mut ret = HashMap::new(); + * ret.insert("k", ...); + * ret.insert("k2", ...); + * ret + * } + */ + private fun renderMap( + writer: RustWriter, + shape: MapShape, + data: ObjectNode + ) { + writer.rustBlock("") { + if (data.members.isNotEmpty()) { + write("let mut ret = \$T::new();", RuntimeType.HashMap) + val valueShape = shape.value.let { model.expectShape(it.target) } + data.members.forEach { (k, v) -> + withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") { + render(v, valueShape, this) + } + } + write("ret") + } else { + writer.write("\$T::new()", RuntimeType.HashMap) + } + } + } + + /** + * ```rust + * MyUnion::Variant(...) + * ``` + */ + private fun renderUnion( + writer: RustWriter, + shape: UnionShape, + data: ObjectNode + ) { + val unionSymbol = symbolProvider.toSymbol(shape) + check(data.members.size == 1) + val variant = data.members.iterator().next() + val memberName = variant.key.value + val member = shape.getMember(memberName).get() + .let { model.expectShape(it.target) } + // TODO: refactor this detail into UnionGenerator + writer.write("\$T::${memberName.toPascalCase()}", unionSymbol) + // unions should specify exactly one member + writer.withBlock("(", ")") { + render(variant.value, member, this) + } + } + + /** + * ```rust + * vec![..., ..., ...] + * ``` + */ + private fun renderList( + writer: RustWriter, + shape: ListShape, + data: ArrayNode + ) { + val member = model.expectShape(shape.member.target) + val memberSymbol = symbolProvider.toSymbol(shape.member) + writer.withBlock("vec![", "]") { + data.elements.forEach { + if (it.isNullNode) { + write("None") + } else { + withBlock("Some(", ")", conditional = memberSymbol.isOptional()) { + render(it, member, this) + } + } + write(",") + } + } + } + + private fun renderString( + writer: RustWriter, + shape: StringShape, + arg: StringNode + ) { + val enumTrait = shape.getTrait(EnumTrait::class.java).orElse(null) + val data = arg.value.dq() + if (enumTrait == null) { + writer.write("$data.to_string()") + } else { + val enumSymbol = symbolProvider.toSymbol(shape) + writer.write("\$T::from($data)", enumSymbol) + } + } + + /** + * ```rust + * MyStruct::builder().field_1("hello").field_2(5).build() + * ``` + */ + private fun renderStructure( + writer: RustWriter, + shape: StructureShape, + data: ObjectNode + ) { + writer.write("\$T::builder()", symbolProvider.toSymbol(shape)) + data.members.forEach { (key, value) -> + val (memberShape, targetShape) = getMember(shape, key) + val func = symbolProvider.toMemberName(memberShape) + if (!value.isNullNode) { + writer.withBlock(".$func(", ")") { + render(value, targetShape, this) + } + } + } + writer.write(".build()") + if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + writer.write(".unwrap()") + } + } + + private fun getMember(shape: StructureShape, key: StringNode): Pair { + val memberShape = shape.getMember(key.value) + .orElseThrow { IllegalArgumentException("$shape did not have member ${key.value}") } + return memberShape to model.expectShape(memberShape.target) + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt index 68fe443c9..60702bbf1 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt @@ -37,7 +37,7 @@ class ServiceGenerator( ) private val protocols: MutableMap = ServiceIndex(model).getProtocols(serviceShape) - private val matchingProtocols = protocols.keys.mapNotNull { supportedProtocols[it] } + private val matchingProtocols = protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } } init { if (matchingProtocols.isEmpty()) { @@ -47,12 +47,15 @@ class ServiceGenerator( fun render() { val operations = index.getContainedOperations(serviceShape) + val (protocol, generator) = matchingProtocols.first() + // TODO: refactor so that we don't need to re-instantiate the protocol for every operation operations.map { operation -> writers.useShapeWriter(operation) { writer -> // transform ensures that all models have input shapes val input = operation.input.get().let { model.expectShape(it, StructureShape::class.java) } - val config = ProtocolConfig(model, symbolProvider, runtimeConfig, writer, serviceShape, operation, input) - matchingProtocols.first().build(config).render() + val config = ProtocolConfig(model, symbolProvider, runtimeConfig, writer, serviceShape, operation, input, protocol) + generator.build(config).render() + HttpProtocolTestGenerator(config).render() } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt index b7297fa09..7ff158528 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt @@ -34,7 +34,9 @@ class StructureGenerator( ) { private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) - private val builderSymbol = RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}") + private val builderSymbol = + RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}") + fun render() { renderStructure() val errorTrait = shape.getTrait(ErrorTrait::class.java) @@ -50,6 +52,16 @@ class StructureGenerator( } } + companion object { + fun fallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = structureShape + .allMembers + .values.map { symbolProvider.toSymbol(it) }.any { + // If any members are not optional && we can't use a default, we need to + // generate a fallible builder + !it.isOptional() && !it.canUseDefault() + } + } + private fun renderStructure() { val symbol = symbolProvider.toSymbol(shape) // TODO(maybe): Pull derive info from the symbol so that the symbol provider can alter things as necessary; 4h @@ -115,19 +127,14 @@ class StructureGenerator( } } - val fallible = members.map { symbolProvider.toSymbol(it) }.any { - // If any members are not optional && we can't use a default, we need to - // generate a fallible builder - !it.isOptional() && !it.canUseDefault() - } - - val returnType = when (fallible) { + val fallibleBuilder = fallibleBuilder(shape, symbolProvider) + val returnType = when (fallibleBuilder) { true -> "Result<\$T, String>" false -> "\$T" } writer.rustBlock("pub fn build(self) -> $returnType", structureSymbol) { - withBlock("Ok(", ")", conditional = fallible) { + withBlock("Ok(", ")", conditional = fallibleBuilder) { rustBlock("\$T", structureSymbol) { members.forEach { member -> val memberName = symbolProvider.toMemberName(member) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt index 3dcf0693f..3eebe6bad 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt @@ -11,7 +11,7 @@ import software.amazon.smithy.rust.codegen.smithy.letIf class CommandFailed(output: String) : Exception("Command Failed\n$output") -fun String.runCommand(workdir: Path? = null): String? { +fun String.runCommand(workdir: Path? = null): String { val parts = this.split("\\s".toRegex()) val proc = ProcessBuilder(*parts.toTypedArray()) .redirectOutput(ProcessBuilder.Redirect.PIPE) @@ -22,10 +22,11 @@ fun String.runCommand(workdir: Path? = null): String? { .start() proc.waitFor(60, TimeUnit.MINUTES) - if (proc.exitValue() != 0) { - val stdErr = proc.errorStream.bufferedReader().readText() - val stdOut = proc.inputStream.bufferedReader().readText() - throw CommandFailed("Command Failed\n$stdErr\n$stdOut") + val stdErr = proc.errorStream.bufferedReader().readText() + val stdOut = proc.inputStream.bufferedReader().readText() + val output = "$stdErr\n$stdOut" + return when (proc.exitValue()) { + 0 -> output + else -> throw CommandFailed("Command Failed\n$output") } - return proc.inputStream.bufferedReader().readText() } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt new file mode 100644 index 000000000..9dd894110 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -0,0 +1,9 @@ +package software.amazon.smithy.rust.codegen.util + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId + +inline fun Model.lookup(shapeId: String): T { + return this.expectShape(ShapeId.from(shapeId), T::class.java) +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt new file mode 100644 index 000000000..3cc156839 --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt @@ -0,0 +1,151 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.util.CommandFailed +import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.testutil.asSmithy +import software.amazon.smithy.rust.testutil.shouldCompile +import software.amazon.smithy.rust.testutil.testSymbolProvider + +class HttpProtocolTestGeneratorTest { + val baseModel = """ + namespace com.example + + use aws.protocols#restJson1 + use smithy.test#httpRequestTests + + @restJson1 + service HelloService { + operations: [SayHello], + version: "1" + } + + @http(method: "POST", uri: "/") + @httpRequestTests([ + { + id: "say_hello", + protocol: restJson1, + params: { + "greeting": "Hi", + "name": "Teddy", + "query": "Hello there" + }, + method: "POST", + uri: "/", + queryParams: [ + "Hi=Hello%20there" + ], + headers: { + "X-Greeting": "Hi", + }, + body: "{\"name\": \"Teddy\"}", + bodyMediaType: "application/json" + } + ]) + operation SayHello { + input: SayHelloInput + } + + structure SayHelloInput { + @httpHeader("X-Greeting") + greeting: String, + + @httpQuery("Hi") + query: String, + + name: String + } + """.asSmithy() + private val model = OperationNormalizer(testSymbolProvider(baseModel)).addOperationInputs(baseModel) + private val symbolProvider = testSymbolProvider(model) + private val runtimeConfig = TestRuntimeConfig + + private fun fakeInput(writer: RustWriter, body: String) { + StructureGenerator(model, symbolProvider, writer, model.lookup("com.example#SayHelloInput")).render() + writer.rustBlock("impl SayHelloInput") { + rustBlock("pub fn build_http_request(&self) -> \$T", RuntimeType.HttpRequestBuilder) { + write("\$T::new()", RuntimeType.HttpRequestBuilder) + write(body) + } + } + } + + @Test + fun `passing e2e protocol request test`() { + val writer = RustWriter.forModule("lib") + + // Hard coded implementation for this 1 test + writer.withModule("operation") { + fakeInput( + this, + """ + .uri("/?Hi=Hello%20there") + .header("X-Greeting", "Hi") + .method("POST") + """ + ) + val protocolConfig = ProtocolConfig( + model, + symbolProvider, + runtimeConfig, + this, + model.lookup("com.example#HelloService"), + model.lookup("com.example#SayHello"), + model.lookup("com.example#SayHelloInput"), + RestJson1Trait.ID + ) + HttpProtocolTestGenerator( + protocolConfig + ).render() + } + + val testOutput = writer.shouldCompile() + // Verify the test actually ran + testOutput shouldContain "test_say_hello ... ok" + } + + @Test + fun `failing e2e protocol test`() { + val writer = RustWriter.forModule("lib") + + // Hard coded implementation for this 1 test + writer.withModule("operation") { + fakeInput( + this, + """ + .uri("/?Hi=INCORRECT") + .header("X-Greeting", "Hi") + .method("POST") + """ + ) + val protocolConfig = ProtocolConfig( + model, + symbolProvider, + runtimeConfig, + this, + model.lookup("com.example#HelloService"), + model.lookup("com.example#SayHello"), + model.lookup("com.example#SayHelloInput"), + RestJson1Trait.ID + ) + HttpProtocolTestGenerator( + protocolConfig + ).render() + } + + val err = assertThrows { + writer.shouldCompile(expectFailure = true) + } + // Verify the test actually ran + err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "MissingQueryParam" + } +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt new file mode 100644 index 000000000..0b7422b98 --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt @@ -0,0 +1,165 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.testutil.asSmithy +import software.amazon.smithy.rust.testutil.shouldCompile +import software.amazon.smithy.rust.testutil.testSymbolProvider + +class InstantiatorTest { + private val model = """ + namespace com.test + @documentation("this documents the shape") + structure MyStruct { + foo: String, + @documentation("This *is* documentation about the member.") + bar: PrimitiveInteger, + baz: Integer, + ts: Timestamp, + byteValue: Byte + } + + list MyList { + member: String + } + + @sparse + list MySparseList { + member: String + } + + union MyUnion { + stringVariant: String, + numVariant: Integer + } + + structure Inner { + map: NestedMap + } + + + map NestedMap { + key: String, + value: Inner + } + """.asSmithy() + + private val symbolProvider = testSymbolProvider(model) + private val runtimeConfig = TestRuntimeConfig + + // TODO: test of recursive structures when supported + + @Test + fun `generate unions`() { + val union = model.lookup("com.test#MyUnion") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val data = Node.parse("""{ + "stringVariant": "ok!" + }""") + val writer = RustWriter.forModule("model") + UnionGenerator(model, symbolProvider, writer, union).render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, union, this) + } + writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));") + } + } + + @Test + fun `generate struct builders`() { + val structure = model.lookup("com.test#MyStruct") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val data = Node.parse(""" { + "bar": 10, + "foo": "hello" + } + """.trimIndent()) + val writer = RustWriter.forModule("model") + val structureGenerator = StructureGenerator(model, symbolProvider, writer, structure) + structureGenerator.render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, structure, this) + } + writer.write("assert_eq!(result.bar, 10);") + writer.write("assert_eq!(result.foo.unwrap(), \"hello\");") + } + writer.shouldCompile() + } + + @Test + fun `generate lists`() { + val data = Node.parse(""" [ + "bar", + "foo" + ] + """) + val writer = RustWriter.forModule("lib") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#MyList"), writer) + } + writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""") + } + writer.shouldCompile() + } + + @Test + fun `generate sparse lists`() { + val data = Node.parse(""" [ + "bar", + "foo", + null + ] + """) + val writer = RustWriter.forModule("lib") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#MySparseList"), writer) + } + writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""") + } + writer.shouldCompile() + } + + @Test + fun `generate maps of maps`() { + val data = Node.parse("""{ + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } + } + """) + val writer = RustWriter.forModule("model") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val structureGenerator = StructureGenerator(model, symbolProvider, writer, model.lookup("com.test#Inner")) + structureGenerator.render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#NestedMap"), writer) + } + writer.write(""" + assert_eq!(result.len(), 3); + assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); + assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); + assert_eq!(result.get("k3").unwrap().map, None); + """) + } + writer.shouldCompile(strict = true) + } +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt index 91907000d..80361ca2a 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt @@ -18,18 +18,30 @@ fun String.shouldParseAsRust() { "rustfmt ${tempFile.absolutePath}".runCommand() } -fun RustWriter.shouldCompile(main: String = "") { +fun RustWriter.shouldCompile(main: String = "", strict: Boolean = false, expectFailure: Boolean = false): String { val deps = this.dependencies.map { RustDependency.fromSymbolDependency(it) } try { - this.toString().shouldCompile(deps.toSet(), module = this.namespace.split("::")[1], main = main) + val output = this.toString() + .shouldCompile(deps.toSet(), module = this.namespace.split("::")[1], main = main, strict = strict) + if (expectFailure) { + println(this.toString()) + } + return output } catch (e: CommandFailed) { // When the test fails, print the code for convenience - println(this.toString()) + if (!expectFailure) { + println(this.toString()) + } throw e } } -fun String.shouldCompile(deps: Set, module: String? = null, main: String = "") { +fun String.shouldCompile( + deps: Set, + module: String? = null, + main: String = "", + strict: Boolean = false +): String { this.shouldParseAsRust() val tempDir = createTempDir() // TODO: unify this with CargoTomlGenerator @@ -39,7 +51,7 @@ fun String.shouldCompile(deps: Set, module: String? = null, main version = "0.0.1" authors = ["rcoh@amazon.com"] edition = "2018" - + [dependencies] ${deps.joinToString("\n") { it.toString() }} """.trimIndent() @@ -48,22 +60,28 @@ fun String.shouldCompile(deps: Set, module: String? = null, main val mainRs = tempDir.resolve("src/main.rs") val testModule = tempDir.resolve("src/$module.rs") testModule.writeText(this) - testModule.appendText(""" + testModule.appendText( + """ #[test] fn test() { $main - } - """.trimIndent()) - mainRs.appendText(""" + } + """.trimIndent() + ) + mainRs.appendText( + """ pub mod $module; use crate::$module::*; fn main() { } - """.trimIndent()) + """.trimIndent() + ) "cargo check".runCommand(tempDir.toPath()) - if (main != "") { - "cargo test".runCommand(tempDir.toPath()) + val testOutput = "cargo test".runCommand(tempDir.toPath()) + if (strict) { + "cargo clippy -- -D warnings".runCommand(tempDir.toPath()) } + return testOutput } fun String.shouldCompile() { diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt index 8274a3e15..695804753 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt @@ -21,5 +21,5 @@ fun testSymbolProvider(model: Model): SymbolProvider = SymbolVisitor(model, "tes private val SmithyVersion = "1.0" fun String.asSmithy(sourceLocation: String? = null): Model { val processed = letIf(!this.startsWith("\$version")) { "\$version: ${SmithyVersion.dq()}\n$it" } - return Model.assembler().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble().unwrap() + return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble().unwrap() } diff --git a/rust-runtime/smithy-types/src/lib.rs b/rust-runtime/smithy-types/src/lib.rs index 491f255b6..158a9669c 100644 --- a/rust-runtime/smithy-types/src/lib.rs +++ b/rust-runtime/smithy-types/src/lib.rs @@ -85,6 +85,12 @@ pub struct Blob { inner: Vec, } +impl Blob { + pub fn new(inp: Vec) -> Self { + Blob { inner: inp } + } +} + #[cfg(test)] mod test { use crate::instant::Format; -- GitLab