diff --git a/codegen-test/build.gradle.kts b/codegen-test/build.gradle.kts index 8e6a6a69ee7d3373a3fa4e405ada53dd74b0a658..7efd698c795f1b5ab613972b2bd400e92fac3ad0 100644 --- a/codegen-test/build.gradle.kts +++ b/codegen-test/build.gradle.kts @@ -93,6 +93,14 @@ tasks.register("cargoCheck") { dependsOn("build") } +tasks.register("cargoTest") { + workingDir("build/smithyprojections/codegen-test/") + // disallow warnings + environment("RUSTFLAGS", "-D warnings") + commandLine("cargo", "test") + dependsOn("build") +} + tasks.register("cargoClippy") { workingDir("build/smithyprojections/codegen-test/") // disallow warnings @@ -101,7 +109,7 @@ tasks.register("cargoClippy") { dependsOn("build") } -tasks["test"].finalizedBy("cargoCheck", "cargoClippy") +tasks["test"].finalizedBy("cargoCheck", "cargoClippy", "cargoTest") tasks["clean"].doFirst { delete("smithy-build.json") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt index cce8c3b3ab85934de36582cc767479c44ae89161..7763ed679c3c2ba850217140890146cd32a0f863 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustDependency.kt @@ -9,13 +9,18 @@ import software.amazon.smithy.codegen.core.SymbolDependency import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +sealed class DependencyScope +object Dev : DependencyScope() +object Compile : DependencyScope() + sealed class DependencyLocation data class CratesIo(val version: String) : DependencyLocation() data class Local(val path: String? = null) : DependencyLocation() data class RustDependency( val name: String, - val location: DependencyLocation + val location: DependencyLocation, + val scope: DependencyScope = Compile ) : SymbolDependencyContainer { override fun getDependencies(): List { return listOf( @@ -44,6 +49,10 @@ data class RustDependency( "${runtimeConfig.cratePrefix}-http", Local(runtimeConfig.relativePath) ) + fun ProtocolTestHelpers(runtimeConfig: RuntimeConfig) = RustDependency( + "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = Dev + ) + private val PropKey = "rustdep" fun fromSymbolDependency(symbolDependency: SymbolDependency) = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt index 08b5b137607fcc9d3134d124d00dc7d8dde0897f..86ae3566dbf445413c302be710c307e6dc3bfdb4 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt @@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.runCommand private val PublicModules = listOf("error", "operation", "model") @@ -71,7 +72,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { LibRsGenerator(modules, it).render() } writers.flushWriters() - "cargo fmt".runCommand(fileManifest.baseDir) + try { + "cargo fmt".runCommand(fileManifest.baseDir) + } catch (_: CommandFailed) { + logger.warning("Generated output did not parse") + } } override fun getDefault(shape: Shape?) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index 8703add78ce1c2b3650bee77b091e6b968fba608..c3df3fd497ca3c2940e15868804ffee68eef48c3 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -40,12 +40,14 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na // TODO: refactor to be RuntimeTypeProvider a la Symbol provider that packages the `RuntimeConfig` state. companion object { + // val Blob = RuntimeType("Blob", RustDependency.IO_CORE, "blob") val From = RuntimeType("From", dependency = null, namespace = "std::convert") val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert") fun StdFmt(member: String) = RuntimeType("fmt::$member", dependency = null, namespace = "std") val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") val HashSet = RuntimeType("HashSet", dependency = null, namespace = "std::collections") + val HashMap = RuntimeType("HashMap", dependency = null, namespace = "std::collections") fun Instant(runtimeConfig: RuntimeConfig) = RuntimeType("Instant", RustDependency.SmithyTypes(runtimeConfig), "${runtimeConfig.cratePrefix}_types") @@ -62,6 +64,9 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na fun Base64Encode(runtimeConfig: RuntimeConfig): RuntimeType = RuntimeType("encode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64") + fun Base64Decode(runtimeConfig: RuntimeConfig): RuntimeType = + RuntimeType("decode", RustDependency.SmithyHttp(runtimeConfig), "${runtimeConfig.cratePrefix}_http::base64") + fun TimestampFormat(runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format): RuntimeType { val timestampFormat = when (format) { TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" @@ -76,6 +81,11 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na ) } + fun ProtocolTestHelper(runtimeConfig: RuntimeConfig, func: String): RuntimeType = + RuntimeType( + func, RustDependency.ProtocolTestHelpers(runtimeConfig), "protocol_test_helpers" + ) + fun Http(path: String): RuntimeType = RuntimeType(name = path, dependency = RustDependency.Http, namespace = "http") val HttpRequestBuilder = Http("request::Builder") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt index 38d0ba041500796dc570700af379d778eaf691e2..bba2fa0d663997594862951fc0a93be82d8271f9 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/CargoTomlGenerator.kt @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.smithy.generators +import software.amazon.smithy.rust.codegen.lang.Compile +import software.amazon.smithy.rust.codegen.lang.Dev import software.amazon.smithy.rust.codegen.lang.RustDependency import software.amazon.smithy.rust.codegen.smithy.RustSettings import software.amazon.smithy.utils.CodeWriter @@ -20,9 +22,18 @@ class CargoTomlGenerator(private val settings: RustSettings, private val writer: writer.insertTrailingNewline() - if (dependencies.isNotEmpty()) { + val compileDependencies = dependencies.filter { it.scope == Compile } + val devDependencies = dependencies.filter { it.scope == Dev } + if (compileDependencies.isNotEmpty()) { writer.write("[dependencies]") - dependencies.forEach { + compileDependencies.forEach { + writer.write(it.toString()) + } + } + + if (devDependencies.isNotEmpty()) { + writer.write("[dev-dependencies]") + devDependencies.forEach { writer.write(it.toString()) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt index c10c79864ef0b775e823c7f5a198555ce0f46836..3a6da590a87a2eb4aca3acfd76a95803f7236a9e 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.rustBlock @@ -22,7 +23,8 @@ data class ProtocolConfig( val writer: RustWriter, val serviceShape: ServiceShape, val operationShape: OperationShape, - val inputShape: StructureShape + val inputShape: StructureShape, + val protocol: ShapeId ) interface ProtocolGeneratorFactory { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt new file mode 100644 index 0000000000000000000000000000000000000000..6ddb1f64a2096e306f28a55124df71ae79d69883 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -0,0 +1,66 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase +import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.util.dq + +class HttpProtocolTestGenerator(private val protocolConfig: ProtocolConfig) { + fun render() { + with(protocolConfig) { + operationShape.getTrait(HttpRequestTestsTrait::class.java).map { + renderHttpRequestTests(it) + } + } + } + + private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) { + with(protocolConfig) { + writer.write("#[cfg(test)]") + val operationName = symbolProvider.toSymbol(operationShape).name + val testModuleName = "${operationName.toSnakeCase()}_request_test" + writer.withModule(testModuleName) { + httpRequestTestsTrait.testCases.filter { it.protocol == protocol }.forEach { testCase -> + renderHttpRequestTestCase(testCase, this) + } + } + } + } + + private val instantiator = with(protocolConfig) { + Instantiator(symbolProvider, model, runtimeConfig) + } + + private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) { + httpRequestTestCase.documentation.map { + testModuleWriter.setNewlinePrefix("/// ").write(it).setNewlinePrefix("") + } + testModuleWriter.write("#[test]") + testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { + writeInline("let input =") + instantiator.render(httpRequestTestCase.params, protocolConfig.inputShape, this) + write(";") + write("let http_request = input.build_http_request().body(()).unwrap();") + with(httpRequestTestCase) { + write( + """ + assert_eq!(http_request.method(), ${method.dq()}); + assert_eq!(http_request.uri().path(), ${uri.dq()}); + """ + ) + withBlock("let expected_query_params = vec![", "];") { + write(queryParams.joinToString(",") { it.dq() }) + } + write( + "\$T(&http_request, expected_query_params.as_slice()).unwrap();", + RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_query_string") + ) + // TODO: assert on the body contents + write("/* BODY:\n ${body.orElse("[ No Body ]")} */") + } + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt new file mode 100644 index 0000000000000000000000000000000000000000..c541487eae901e491a92fd9576ed9e0cc357a3d8 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -0,0 +1,197 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ArrayNode +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.NumberNode +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.node.StringNode +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.util.dq + +/** + * Instantiator generates code to instantiate a given Shape given a `Node` representing the value + * + * This is primarily used during Protocol test generation + */ +class Instantiator( + private val symbolProvider: SymbolProvider, + private val model: Model, + private val runtimeConfig: RuntimeConfig +) { + + fun render(arg: Node, shape: Shape, writer: RustWriter) { + when (shape) { + // Compound Shapes + is StructureShape -> renderStructure(writer, shape, arg as ObjectNode) + is UnionShape -> renderUnion(writer, shape, arg as ObjectNode) + + // Collections + is ListShape -> renderList(writer, shape, arg as ArrayNode) + is MapShape -> renderMap(writer, shape, arg as ObjectNode) + + // Wrapped Shapes + is TimestampShape -> writer.write( + "\$T::from_epoch_seconds(${(arg as NumberNode).value})", + RuntimeType.Instant(runtimeConfig) + ) + + is BlobShape -> writer.write( + "\$T::new(\$T(${(arg as StringNode).value.dq()}).unwrap())", + RuntimeType.Blob(runtimeConfig), + RuntimeType.Base64Decode(runtimeConfig) + ) + + // Simple Shapes + is StringShape -> renderString(writer, shape, arg as StringNode) + is NumberShape -> writer.write(arg.asNumberNode().get()) + is BooleanShape -> writer.write(arg.asBooleanNode().get().toString()) + else -> writer.write("todo!() /* $shape $arg */") + } + } + + /** + * ```rust + * { + * let mut ret = HashMap::new(); + * ret.insert("k", ...); + * ret.insert("k2", ...); + * ret + * } + */ + private fun renderMap( + writer: RustWriter, + shape: MapShape, + data: ObjectNode + ) { + writer.rustBlock("") { + if (data.members.isNotEmpty()) { + write("let mut ret = \$T::new();", RuntimeType.HashMap) + val valueShape = shape.value.let { model.expectShape(it.target) } + data.members.forEach { (k, v) -> + withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") { + render(v, valueShape, this) + } + } + write("ret") + } else { + writer.write("\$T::new()", RuntimeType.HashMap) + } + } + } + + /** + * ```rust + * MyUnion::Variant(...) + * ``` + */ + private fun renderUnion( + writer: RustWriter, + shape: UnionShape, + data: ObjectNode + ) { + val unionSymbol = symbolProvider.toSymbol(shape) + check(data.members.size == 1) + val variant = data.members.iterator().next() + val memberName = variant.key.value + val member = shape.getMember(memberName).get() + .let { model.expectShape(it.target) } + // TODO: refactor this detail into UnionGenerator + writer.write("\$T::${memberName.toPascalCase()}", unionSymbol) + // unions should specify exactly one member + writer.withBlock("(", ")") { + render(variant.value, member, this) + } + } + + /** + * ```rust + * vec![..., ..., ...] + * ``` + */ + private fun renderList( + writer: RustWriter, + shape: ListShape, + data: ArrayNode + ) { + val member = model.expectShape(shape.member.target) + val memberSymbol = symbolProvider.toSymbol(shape.member) + writer.withBlock("vec![", "]") { + data.elements.forEach { + if (it.isNullNode) { + write("None") + } else { + withBlock("Some(", ")", conditional = memberSymbol.isOptional()) { + render(it, member, this) + } + } + write(",") + } + } + } + + private fun renderString( + writer: RustWriter, + shape: StringShape, + arg: StringNode + ) { + val enumTrait = shape.getTrait(EnumTrait::class.java).orElse(null) + val data = arg.value.dq() + if (enumTrait == null) { + writer.write("$data.to_string()") + } else { + val enumSymbol = symbolProvider.toSymbol(shape) + writer.write("\$T::from($data)", enumSymbol) + } + } + + /** + * ```rust + * MyStruct::builder().field_1("hello").field_2(5).build() + * ``` + */ + private fun renderStructure( + writer: RustWriter, + shape: StructureShape, + data: ObjectNode + ) { + writer.write("\$T::builder()", symbolProvider.toSymbol(shape)) + data.members.forEach { (key, value) -> + val (memberShape, targetShape) = getMember(shape, key) + val func = symbolProvider.toMemberName(memberShape) + if (!value.isNullNode) { + writer.withBlock(".$func(", ")") { + render(value, targetShape, this) + } + } + } + writer.write(".build()") + if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + writer.write(".unwrap()") + } + } + + private fun getMember(shape: StructureShape, key: StringNode): Pair { + val memberShape = shape.getMember(key.value) + .orElseThrow { IllegalArgumentException("$shape did not have member ${key.value}") } + return memberShape to model.expectShape(memberShape.target) + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt index 68fe443c954b2b5d6971eeb3480c61df8451aab4..60702bbf1b8b400ca797170c3ba42be65a67bf30 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt @@ -37,7 +37,7 @@ class ServiceGenerator( ) private val protocols: MutableMap = ServiceIndex(model).getProtocols(serviceShape) - private val matchingProtocols = protocols.keys.mapNotNull { supportedProtocols[it] } + private val matchingProtocols = protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } } init { if (matchingProtocols.isEmpty()) { @@ -47,12 +47,15 @@ class ServiceGenerator( fun render() { val operations = index.getContainedOperations(serviceShape) + val (protocol, generator) = matchingProtocols.first() + // TODO: refactor so that we don't need to re-instantiate the protocol for every operation operations.map { operation -> writers.useShapeWriter(operation) { writer -> // transform ensures that all models have input shapes val input = operation.input.get().let { model.expectShape(it, StructureShape::class.java) } - val config = ProtocolConfig(model, symbolProvider, runtimeConfig, writer, serviceShape, operation, input) - matchingProtocols.first().build(config).render() + val config = ProtocolConfig(model, symbolProvider, runtimeConfig, writer, serviceShape, operation, input, protocol) + generator.build(config).render() + HttpProtocolTestGenerator(config).render() } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt index b7297fa090671938986cfb5a5e6652e29aeed034..7ff1585283d508a9a702b348f34923ad5e053d4b 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt @@ -34,7 +34,9 @@ class StructureGenerator( ) { private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) - private val builderSymbol = RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}") + private val builderSymbol = + RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}") + fun render() { renderStructure() val errorTrait = shape.getTrait(ErrorTrait::class.java) @@ -50,6 +52,16 @@ class StructureGenerator( } } + companion object { + fun fallibleBuilder(structureShape: StructureShape, symbolProvider: SymbolProvider): Boolean = structureShape + .allMembers + .values.map { symbolProvider.toSymbol(it) }.any { + // If any members are not optional && we can't use a default, we need to + // generate a fallible builder + !it.isOptional() && !it.canUseDefault() + } + } + private fun renderStructure() { val symbol = symbolProvider.toSymbol(shape) // TODO(maybe): Pull derive info from the symbol so that the symbol provider can alter things as necessary; 4h @@ -115,19 +127,14 @@ class StructureGenerator( } } - val fallible = members.map { symbolProvider.toSymbol(it) }.any { - // If any members are not optional && we can't use a default, we need to - // generate a fallible builder - !it.isOptional() && !it.canUseDefault() - } - - val returnType = when (fallible) { + val fallibleBuilder = fallibleBuilder(shape, symbolProvider) + val returnType = when (fallibleBuilder) { true -> "Result<\$T, String>" false -> "\$T" } writer.rustBlock("pub fn build(self) -> $returnType", structureSymbol) { - withBlock("Ok(", ")", conditional = fallible) { + withBlock("Ok(", ")", conditional = fallibleBuilder) { rustBlock("\$T", structureSymbol) { members.forEach { member -> val memberName = symbolProvider.toMemberName(member) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt index 3dcf0693f3a609461dac7283f6e3a59324955c6f..3eebe6bada95a99ef622b76bb44cef613b99d1d5 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt @@ -11,7 +11,7 @@ import software.amazon.smithy.rust.codegen.smithy.letIf class CommandFailed(output: String) : Exception("Command Failed\n$output") -fun String.runCommand(workdir: Path? = null): String? { +fun String.runCommand(workdir: Path? = null): String { val parts = this.split("\\s".toRegex()) val proc = ProcessBuilder(*parts.toTypedArray()) .redirectOutput(ProcessBuilder.Redirect.PIPE) @@ -22,10 +22,11 @@ fun String.runCommand(workdir: Path? = null): String? { .start() proc.waitFor(60, TimeUnit.MINUTES) - if (proc.exitValue() != 0) { - val stdErr = proc.errorStream.bufferedReader().readText() - val stdOut = proc.inputStream.bufferedReader().readText() - throw CommandFailed("Command Failed\n$stdErr\n$stdOut") + val stdErr = proc.errorStream.bufferedReader().readText() + val stdOut = proc.inputStream.bufferedReader().readText() + val output = "$stdErr\n$stdOut" + return when (proc.exitValue()) { + 0 -> output + else -> throw CommandFailed("Command Failed\n$output") } - return proc.inputStream.bufferedReader().readText() } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt new file mode 100644 index 0000000000000000000000000000000000000000..9dd8941102e19e53b7fdbfd30af41d154331ff25 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -0,0 +1,9 @@ +package software.amazon.smithy.rust.codegen.util + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId + +inline fun Model.lookup(shapeId: String): T { + return this.expectShape(ShapeId.from(shapeId), T::class.java) +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..3cc156839dea1411cd8d4d32bc5b51d429d1996b --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt @@ -0,0 +1,151 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.util.CommandFailed +import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.testutil.asSmithy +import software.amazon.smithy.rust.testutil.shouldCompile +import software.amazon.smithy.rust.testutil.testSymbolProvider + +class HttpProtocolTestGeneratorTest { + val baseModel = """ + namespace com.example + + use aws.protocols#restJson1 + use smithy.test#httpRequestTests + + @restJson1 + service HelloService { + operations: [SayHello], + version: "1" + } + + @http(method: "POST", uri: "/") + @httpRequestTests([ + { + id: "say_hello", + protocol: restJson1, + params: { + "greeting": "Hi", + "name": "Teddy", + "query": "Hello there" + }, + method: "POST", + uri: "/", + queryParams: [ + "Hi=Hello%20there" + ], + headers: { + "X-Greeting": "Hi", + }, + body: "{\"name\": \"Teddy\"}", + bodyMediaType: "application/json" + } + ]) + operation SayHello { + input: SayHelloInput + } + + structure SayHelloInput { + @httpHeader("X-Greeting") + greeting: String, + + @httpQuery("Hi") + query: String, + + name: String + } + """.asSmithy() + private val model = OperationNormalizer(testSymbolProvider(baseModel)).addOperationInputs(baseModel) + private val symbolProvider = testSymbolProvider(model) + private val runtimeConfig = TestRuntimeConfig + + private fun fakeInput(writer: RustWriter, body: String) { + StructureGenerator(model, symbolProvider, writer, model.lookup("com.example#SayHelloInput")).render() + writer.rustBlock("impl SayHelloInput") { + rustBlock("pub fn build_http_request(&self) -> \$T", RuntimeType.HttpRequestBuilder) { + write("\$T::new()", RuntimeType.HttpRequestBuilder) + write(body) + } + } + } + + @Test + fun `passing e2e protocol request test`() { + val writer = RustWriter.forModule("lib") + + // Hard coded implementation for this 1 test + writer.withModule("operation") { + fakeInput( + this, + """ + .uri("/?Hi=Hello%20there") + .header("X-Greeting", "Hi") + .method("POST") + """ + ) + val protocolConfig = ProtocolConfig( + model, + symbolProvider, + runtimeConfig, + this, + model.lookup("com.example#HelloService"), + model.lookup("com.example#SayHello"), + model.lookup("com.example#SayHelloInput"), + RestJson1Trait.ID + ) + HttpProtocolTestGenerator( + protocolConfig + ).render() + } + + val testOutput = writer.shouldCompile() + // Verify the test actually ran + testOutput shouldContain "test_say_hello ... ok" + } + + @Test + fun `failing e2e protocol test`() { + val writer = RustWriter.forModule("lib") + + // Hard coded implementation for this 1 test + writer.withModule("operation") { + fakeInput( + this, + """ + .uri("/?Hi=INCORRECT") + .header("X-Greeting", "Hi") + .method("POST") + """ + ) + val protocolConfig = ProtocolConfig( + model, + symbolProvider, + runtimeConfig, + this, + model.lookup("com.example#HelloService"), + model.lookup("com.example#SayHello"), + model.lookup("com.example#SayHelloInput"), + RestJson1Trait.ID + ) + HttpProtocolTestGenerator( + protocolConfig + ).render() + } + + val err = assertThrows { + writer.shouldCompile(expectFailure = true) + } + // Verify the test actually ran + err.message shouldContain "test_say_hello ... FAILED" + err.message shouldContain "MissingQueryParam" + } +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..0b7422b98f98692eaf285df9a430a8484efe3fae --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt @@ -0,0 +1,165 @@ +package software.amazon.smithy.rust.codegen.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.util.lookup +import software.amazon.smithy.rust.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.testutil.asSmithy +import software.amazon.smithy.rust.testutil.shouldCompile +import software.amazon.smithy.rust.testutil.testSymbolProvider + +class InstantiatorTest { + private val model = """ + namespace com.test + @documentation("this documents the shape") + structure MyStruct { + foo: String, + @documentation("This *is* documentation about the member.") + bar: PrimitiveInteger, + baz: Integer, + ts: Timestamp, + byteValue: Byte + } + + list MyList { + member: String + } + + @sparse + list MySparseList { + member: String + } + + union MyUnion { + stringVariant: String, + numVariant: Integer + } + + structure Inner { + map: NestedMap + } + + + map NestedMap { + key: String, + value: Inner + } + """.asSmithy() + + private val symbolProvider = testSymbolProvider(model) + private val runtimeConfig = TestRuntimeConfig + + // TODO: test of recursive structures when supported + + @Test + fun `generate unions`() { + val union = model.lookup("com.test#MyUnion") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val data = Node.parse("""{ + "stringVariant": "ok!" + }""") + val writer = RustWriter.forModule("model") + UnionGenerator(model, symbolProvider, writer, union).render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, union, this) + } + writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));") + } + } + + @Test + fun `generate struct builders`() { + val structure = model.lookup("com.test#MyStruct") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val data = Node.parse(""" { + "bar": 10, + "foo": "hello" + } + """.trimIndent()) + val writer = RustWriter.forModule("model") + val structureGenerator = StructureGenerator(model, symbolProvider, writer, structure) + structureGenerator.render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, structure, this) + } + writer.write("assert_eq!(result.bar, 10);") + writer.write("assert_eq!(result.foo.unwrap(), \"hello\");") + } + writer.shouldCompile() + } + + @Test + fun `generate lists`() { + val data = Node.parse(""" [ + "bar", + "foo" + ] + """) + val writer = RustWriter.forModule("lib") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#MyList"), writer) + } + writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""") + } + writer.shouldCompile() + } + + @Test + fun `generate sparse lists`() { + val data = Node.parse(""" [ + "bar", + "foo", + null + ] + """) + val writer = RustWriter.forModule("lib") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#MySparseList"), writer) + } + writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""") + } + writer.shouldCompile() + } + + @Test + fun `generate maps of maps`() { + val data = Node.parse("""{ + "k1": { "map": {} }, + "k2": { "map": { "k3": {} } }, + "k3": { } + } + """) + val writer = RustWriter.forModule("model") + val sut = Instantiator(symbolProvider, model, runtimeConfig) + val structureGenerator = StructureGenerator(model, symbolProvider, writer, model.lookup("com.test#Inner")) + structureGenerator.render() + writer.write("#[test]") + writer.rustBlock("fn inst()") { + writer.withBlock("let result = ", ";") { + sut.render(data, model.lookup("com.test#NestedMap"), writer) + } + writer.write(""" + assert_eq!(result.len(), 3); + assert_eq!(result.get("k1").unwrap().map.as_ref().unwrap().len(), 0); + assert_eq!(result.get("k2").unwrap().map.as_ref().unwrap().len(), 1); + assert_eq!(result.get("k3").unwrap().map, None); + """) + } + writer.shouldCompile(strict = true) + } +} diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt index 91907000dd8bf31d81d0c8869ca7f0cbe613c094..80361ca2a68b61c2b0de9b4149e92f1d1d9c6294 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt @@ -18,18 +18,30 @@ fun String.shouldParseAsRust() { "rustfmt ${tempFile.absolutePath}".runCommand() } -fun RustWriter.shouldCompile(main: String = "") { +fun RustWriter.shouldCompile(main: String = "", strict: Boolean = false, expectFailure: Boolean = false): String { val deps = this.dependencies.map { RustDependency.fromSymbolDependency(it) } try { - this.toString().shouldCompile(deps.toSet(), module = this.namespace.split("::")[1], main = main) + val output = this.toString() + .shouldCompile(deps.toSet(), module = this.namespace.split("::")[1], main = main, strict = strict) + if (expectFailure) { + println(this.toString()) + } + return output } catch (e: CommandFailed) { // When the test fails, print the code for convenience - println(this.toString()) + if (!expectFailure) { + println(this.toString()) + } throw e } } -fun String.shouldCompile(deps: Set, module: String? = null, main: String = "") { +fun String.shouldCompile( + deps: Set, + module: String? = null, + main: String = "", + strict: Boolean = false +): String { this.shouldParseAsRust() val tempDir = createTempDir() // TODO: unify this with CargoTomlGenerator @@ -39,7 +51,7 @@ fun String.shouldCompile(deps: Set, module: String? = null, main version = "0.0.1" authors = ["rcoh@amazon.com"] edition = "2018" - + [dependencies] ${deps.joinToString("\n") { it.toString() }} """.trimIndent() @@ -48,22 +60,28 @@ fun String.shouldCompile(deps: Set, module: String? = null, main val mainRs = tempDir.resolve("src/main.rs") val testModule = tempDir.resolve("src/$module.rs") testModule.writeText(this) - testModule.appendText(""" + testModule.appendText( + """ #[test] fn test() { $main - } - """.trimIndent()) - mainRs.appendText(""" + } + """.trimIndent() + ) + mainRs.appendText( + """ pub mod $module; use crate::$module::*; fn main() { } - """.trimIndent()) + """.trimIndent() + ) "cargo check".runCommand(tempDir.toPath()) - if (main != "") { - "cargo test".runCommand(tempDir.toPath()) + val testOutput = "cargo test".runCommand(tempDir.toPath()) + if (strict) { + "cargo clippy -- -D warnings".runCommand(tempDir.toPath()) } + return testOutput } fun String.shouldCompile() { diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt index 8274a3e15afe9ec8b02ccf1796e0b559e0c56bdc..695804753576399cbbb41edf0a0730e9fb50fa78 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/TestHelpers.kt @@ -21,5 +21,5 @@ fun testSymbolProvider(model: Model): SymbolProvider = SymbolVisitor(model, "tes private val SmithyVersion = "1.0" fun String.asSmithy(sourceLocation: String? = null): Model { val processed = letIf(!this.startsWith("\$version")) { "\$version: ${SmithyVersion.dq()}\n$it" } - return Model.assembler().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble().unwrap() + return Model.assembler().discoverModels().addUnparsedModel(sourceLocation ?: "test.smithy", processed).assemble().unwrap() } diff --git a/rust-runtime/smithy-types/src/lib.rs b/rust-runtime/smithy-types/src/lib.rs index 491f255b6559db512718fead3acc7ff688aa8af8..158a9669cbe406ad11acf199fd9769aee4565d16 100644 --- a/rust-runtime/smithy-types/src/lib.rs +++ b/rust-runtime/smithy-types/src/lib.rs @@ -85,6 +85,12 @@ pub struct Blob { inner: Vec, } +impl Blob { + pub fn new(inp: Vec) -> Self { + Blob { inner: inp } + } +} + #[cfg(test)] mod test { use crate::instant::Format;