diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt index f5dc76bfed51f391adb0848822c4cb9d09b7fb98..85ef16cf3e209b90b65b2d9de3770a1d5e52fb9f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt @@ -18,10 +18,19 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.utils.CodeWriter -fun CodeWriter.withBlock(textBeforeNewLine: String, textAfterNewLine: String, block: CodeWriter.() -> Unit): CodeWriter { - openBlock(textBeforeNewLine) +fun CodeWriter.withBlock( + textBeforeNewLine: String, + textAfterNewLine: String, + conditional: Boolean = true, + block: CodeWriter.() -> Unit +): CodeWriter { + if (conditional) { + openBlock(textBeforeNewLine) + } block(this) - closeBlock(textAfterNewLine) + if (conditional) { + closeBlock(textAfterNewLine) + } return this } @@ -35,9 +44,30 @@ fun <T : CodeWriter> T.rustBlock(header: String, vararg args: Any, block: T.() - return this } -class RustWriter(filename: String, private val namespace: String, private val commentCharacter: String = "//") : CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) { +class RustWriter private constructor(private val filename: String, val namespace: String, private val commentCharacter: String = "//") : + CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) { + companion object { + fun forModule(module: String): RustWriter { + return RustWriter("$module.rs", "crate::$module") + } + + val Factory: CodegenWriterFactory<RustWriter> = + CodegenWriterFactory<RustWriter> { filename, namespace -> + when { + filename.endsWith(".toml") -> RustWriter(filename, namespace, "#") + else -> RustWriter(filename, namespace) + } + } + } + init { + if (filename.endsWith(".rs")) { + require(namespace.startsWith("crate")) { "We can only write into files in the crate (got $namespace)" } + } + } + private val formatter = RustSymbolFormatter() private var n = 0 + init { putFormatter('T', formatter) } @@ -47,6 +77,23 @@ class RustWriter(filename: String, private val namespace: String, private val co return "${prefix}_$n" } + /** + * Create an inline module. + * [header] should be the declaration of the module, eg. `pub mod Hello`. + * + * The returned writer will inject any local imports into the module as needed. + */ + fun withModule(moduleName: String, visibility: String = "pub", moduleWriter: RustWriter.() -> Unit) { + // In Rust, modules must specify their own imports—they don't have access to the parent scope. + // To easily handle this, create a new inner writer to collect imports, then dump it + // into an inline module. + val innerWriter = RustWriter(this.filename, "${this.namespace}::$moduleName") + moduleWriter(innerWriter) + rustBlock("$visibility mod $moduleName") { + write(innerWriter.toString()) + } + } + // TODO: refactor both of these methods & add a parent method to for_each across any field type // generically fun OptionForEach(member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) { @@ -99,7 +146,12 @@ class RustWriter(filename: String, private val namespace: String, private val co is RuntimeType -> { t.dependency?.also { addDependency(it) } // for now, use the fully qualified type name - "::${t.namespace}::${t.name}" + val prefix = if (t.namespace.startsWith("crate")) { + "" + } else { + "::" + } + "$prefix${t.namespace}::${t.name}" } is Symbol -> { if (t.namespace != namespace) { @@ -111,12 +163,4 @@ class RustWriter(filename: String, private val namespace: String, private val co } } } - - companion object { - val Factory: CodegenWriterFactory<RustWriter> = - CodegenWriterFactory<RustWriter> { filename, namespace -> when { - filename.endsWith(".toml") -> RustWriter(filename, namespace, "#") - else -> RustWriter(filename, namespace) - } } - } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt index 41342d155b7c2e93b50203442460034501253b1e..08b5b137607fcc9d3134d124d00dc7d8dde0897f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt @@ -65,7 +65,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() { ) cargoToml.render() } - writers.useFileWriter("src/lib.rs") { + writers.useFileWriter("src/lib.rs", "crate::lib") { // TODO: a more structured method of signaling what modules should get loaded. val modules = PublicModules.filter { writers.writers.containsKey("src/$it.rs") } LibRsGenerator(modules, it).render() diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt index 4e449c1bc132350b0dc59dbb510310cd9e7748f0..26b2792b491629890c32ad1fc45e029ee0d5969c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt @@ -69,10 +69,15 @@ fun Symbol.referenceClosure(): List<Symbol> { return listOf(this) + referencedSymbols.flatMap { it.referenceClosure() } } -data class SymbolVisitorConfig(val runtimeConfig: RuntimeConfig, val handleOptionality: Boolean = true, val handleRustBoxing: Boolean = true) +data class SymbolVisitorConfig( + val runtimeConfig: RuntimeConfig, + val handleOptionality: Boolean = true, + val handleRustBoxing: Boolean = true +) // TODO: consider if this is better handled as a wrapper -val DefaultConfig = SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true) +val DefaultConfig = + SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true) data class SymbolLocation(val filename: String, val namespace: String) @@ -84,6 +89,19 @@ val Shapes = SymbolLocation("model.rs", "model") val Errors = SymbolLocation("error.rs", "error") val Operations = SymbolLocation("operation.rs", "operation") +fun Symbol.makeOptional(): Symbol { + return if (isOptional()) { + this + } else { + val rustType = RustType.Option(this.rustType()) + Symbol.builder().rustType(rustType) + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } +} + class SymbolVisitor( private val model: Model, private val rootNamespace: String = "crate", @@ -102,16 +120,12 @@ class SymbolVisitor( } private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol { - val httpLabeledInput = container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java) + // If a field has the httpLabel trait and we are generating + // an Input shape, then the field is _not optional_. + val httpLabeledInput = + container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java) return if (nullableIndex.isNullable(member) && !httpLabeledInput) { - with(Symbol.builder()) { - val rustType = RustType.Option(symbol.rustType()) - rustType(rustType) - addReference(symbol) - name(rustType.name) - putProperty(SHAPE_KEY, member) - build() - } + symbol.makeOptional() } else symbol } @@ -128,7 +142,7 @@ class SymbolVisitor( } private fun simpleShape(shape: SimpleShape): Symbol { - return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build() + return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).canUseDefault().build() } override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape) @@ -239,7 +253,7 @@ class SymbolVisitor( return builder.rustType(rustType) .name(rustType.name) // Every symbol that actually gets defined somewhere should set a definition file - // If we ever generate a `thisisabug.rs`, we messed something up + // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation .definitionFile("thisisabug.rs") } } @@ -247,11 +261,24 @@ class SymbolVisitor( // TODO(chore): Move this to a useful place private const val RUST_TYPE_KEY = "rusttype" private const val SHAPE_KEY = "shape" +private const val CAN_USE_DEFAULT = "canusedefault" fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder { return this.putProperty(RUST_TYPE_KEY, rustType) } +fun Symbol.Builder.canUseDefault(value: Boolean = true): Symbol.Builder { + return this.putProperty(CAN_USE_DEFAULT, value) +} + +/** + * True when it is valid to use the default/0 value for [this] symbol during construction. + */ +fun Symbol.canUseDefault(): Boolean = this.getProperty(CAN_USE_DEFAULT, Boolean::class.javaObjectType).orElse(false) + +/** + * True when [this] is will be represented by Option<T> in Rust + */ fun Symbol.isOptional(): Boolean = when (this.rustType()) { is RustType.Option -> true else -> false diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt index d210df7275c82fb21218009ee418ea7291a21fdd..b7297fa090671938986cfb5a5e6652e29aeed034 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt @@ -10,18 +10,31 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.lang.RustType import software.amazon.smithy.rust.codegen.lang.RustWriter +import software.amazon.smithy.rust.codegen.lang.render +import software.amazon.smithy.rust.codegen.lang.rustBlock +import software.amazon.smithy.rust.codegen.lang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.canUseDefault +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.makeOptional +import software.amazon.smithy.rust.codegen.smithy.rustType +import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.utils.CaseUtils // TODO(maybe): extract struct generation from Smithy shapes to support generating body objects -// TODO: generate builders; 1d +// TODO: generate documentation class StructureGenerator( val model: Model, private val symbolProvider: SymbolProvider, private val writer: RustWriter, - private val shape: StructureShape + private val shape: StructureShape, + private val renderBuilder: Boolean = true ) { - private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) } + private val members: List<MemberShape> = shape.allMembers.values.toList() + private val structureSymbol = symbolProvider.toSymbol(shape) + private val builderSymbol = RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}") fun render() { renderStructure() val errorTrait = shape.getTrait(ErrorTrait::class.java) @@ -29,6 +42,12 @@ class StructureGenerator( val errorGenerator = ErrorGenerator(model, symbolProvider, writer, shape, it) errorGenerator.render() } + if (renderBuilder) { + val symbol = symbolProvider.toSymbol(shape) + writer.withModule(symbol.name.toSnakeCase()) { + renderBuilder(this) + } + } } private fun renderStructure() { @@ -36,11 +55,95 @@ class StructureGenerator( // TODO(maybe): Pull derive info from the symbol so that the symbol provider can alter things as necessary; 4h writer.write("#[non_exhaustive]") writer.write("#[derive(Debug, PartialEq, Clone)]") - val blockWriter = writer.openBlock("pub struct ${symbol.name} {") - sortedMembers.forEach { member -> - val memberName = symbolProvider.toMemberName(member) - blockWriter.write("pub $memberName: \$T,", symbolProvider.toSymbol(member)) } - blockWriter.closeBlock("}") + writer.rustBlock("pub struct ${symbol.name}") { + members.forEach { member -> + val memberName = symbolProvider.toMemberName(member) + write("pub $memberName: \$T,", symbolProvider.toSymbol(member)) + } + } + + if (renderBuilder) { + writer.rustBlock("impl ${symbol.name}") { + rustBlock("pub fn builder() -> \$T", builderSymbol) { + write("\$T::default()", builderSymbol) + } + } + } + } + + private fun renderBuilder(writer: RustWriter) { + // Eventually, I want to do a fancier module layout: + // model/some_model.rs [contains builder and impl for a single model] struct SomeModel, struct Builder + // model/mod.rs [contains pub use for each model to bring it into top level scope] + // users will do models::SomeModel, models::SomeModel::builder() + val builderName = "Builder" + writer.write("#[non_exhaustive]") + writer.write("#[derive(Debug, Clone, Default)]") + writer.rustBlock("pub struct $builderName") { + members.forEach { member -> + val memberName = symbolProvider.toMemberName(member) + // All fields in the builder are optional + val memberSymbol = symbolProvider.toSymbol(member).makeOptional() + // TODO: should the builder members be public? + write("$memberName: \$T,", memberSymbol) + } + } + + fun builderConverter(rustType: RustType) = when (rustType) { + is RustType.String -> "inp.into()" + else -> "inp" + } + + writer.rustBlock("impl $builderName") { + members.forEach { member -> + val memberName = symbolProvider.toMemberName(member) + // All fields in the builder are optional + val memberSymbol = symbolProvider.toSymbol(member) + val coreType = memberSymbol.rustType().let { + when (it) { + is RustType.Option -> it.value + else -> it + } + } + val signature = when (coreType) { + is RustType.String -> "<T: Into<String>>(mut self, inp: T) -> Self" + else -> "(mut self, inp: ${coreType.render()}) -> Self" + } + writer.rustBlock("pub fn $memberName$signature") { + write("self.$memberName = Some(${builderConverter(coreType)});") + write("self") + } + } + + val fallible = members.map { symbolProvider.toSymbol(it) }.any { + // If any members are not optional && we can't use a default, we need to + // generate a fallible builder + !it.isOptional() && !it.canUseDefault() + } + + val returnType = when (fallible) { + true -> "Result<\$T, String>" + false -> "\$T" + } + + writer.rustBlock("pub fn build(self) -> $returnType", structureSymbol) { + withBlock("Ok(", ")", conditional = fallible) { + rustBlock("\$T", structureSymbol) { + members.forEach { member -> + val memberName = symbolProvider.toMemberName(member) + val memberSymbol = symbolProvider.toSymbol(member) + val errorWhenMissing = "$memberName is required when building ${structureSymbol.name}" + val modifier = when { + !memberSymbol.isOptional() && memberSymbol.canUseDefault() -> ".unwrap_or_default()" + !memberSymbol.isOptional() -> ".ok_or(${errorWhenMissing.dq()})?" + else -> "" + } + write("$memberName: self.$memberName$modifier,") + } + } + } + } + } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt index 5719adfe0d4ffc79a1800d85063a0760a9eac18a..3dcf0693f3a609461dac7283f6e3a59324955c6f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Exec.kt @@ -23,8 +23,9 @@ fun String.runCommand(workdir: Path? = null): String? { proc.waitFor(60, TimeUnit.MINUTES) if (proc.exitValue() != 0) { - val output = proc.errorStream.bufferedReader().readText() - throw CommandFailed("Command Failed\n$output") + val stdErr = proc.errorStream.bufferedReader().readText() + val stdOut = proc.inputStream.bufferedReader().readText() + throw CommandFailed("Command Failed\n$stdErr\n$stdOut") } return proc.inputStream.bufferedReader().readText() } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/EnumGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/EnumGeneratorTest.kt index 204a8e4451088212c14f15904bc73863414d30e8..33d4e4d3c4ef2ec9837e63273b8e70240d33d919 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/EnumGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/EnumGeneratorTest.kt @@ -47,7 +47,7 @@ class EnumGeneratorTest { .assemble() .unwrap() val provider: SymbolProvider = SymbolVisitor(model, "test") - val writer = RustWriter("model.rs", "model") + val writer = RustWriter.forModule("model") val generator = EnumGenerator(provider, writer, shape, trait) generator.render() val result = writer.toString() diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt index 98ac52dde4705c887c7d44c66bf8be9d6c61308c..986119377276ad7df2edc7ba911bc7d2579bf5e0 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt @@ -108,26 +108,20 @@ class HttpTraitBindingGeneratorTest { httpTrait.uriFormatString() shouldBe ("/{bucketName}/{key}".dq()) } - // TODO: when we generate builders, use them to clean up these tests; 1h @Test fun `generate uris`() { - val writer = RustWriter("operation.rs", "operation") + val writer = RustWriter.forModule("operation") // currently rendering the operation renders the protocols—I want to separate that at some point. renderOperation(writer) writer.shouldCompile( """ let ts = Instant::from_epoch_seconds(10123125); - let inp = PutObjectInput { - additional: None, - bucket_name: "somebucket/ok".to_string(), - data: None, - date_header_list: None, - key: ts.clone(), - int_list: None, - extras: Some(vec![0, 1,2,44]), - some_value: Some("svq!!%&".to_string()), - media_type: None - }; + let inp = PutObjectInput::builder() + .bucket_name("somebucket/ok") + .key(ts.clone()) + .extras(vec![0,1,2,44]) + .some_value("svq!!%&") + .build().expect("build should succeed"); let mut o = String::new(); inp.uri_base(&mut o); assert_eq!(o.as_str(), "/somebucket%2Fok/1970-04-28T03:58:45Z"); @@ -140,22 +134,20 @@ class HttpTraitBindingGeneratorTest { @Test fun `build http requests`() { - val writer = RustWriter("operation.rs", "operation") + val writer = RustWriter.forModule("operation") renderOperation(writer) writer.shouldCompile( """ let ts = Instant::from_epoch_seconds(10123125); - let inp = PutObjectInput { - additional: None, - bucket_name: "buk".to_string(), - data: None, - date_header_list: Some(vec![ts.clone()]), - int_list: Some(vec![0,1,44]), - key: Instant::from_epoch_seconds(10123125), - extras: Some(vec![0,1]), - some_value: Some("qp".to_string()), - media_type: Some("base64encodethis".to_string()), - }; + let inp = PutObjectInput::builder() + .bucket_name("buk") + .date_header_list(vec![ts.clone()]) + .int_list(vec![0,1,44]) + .key(ts.clone()) + .extras(vec![0,1]) + .some_value("qp") + .media_type("base64encodethis") + .build().unwrap(); let http_request = inp.build_http_request().body(()).unwrap(); assert_eq!(http_request.uri(), "/buk/1970-04-28T03:58:45Z?paramName=qp&hello=0&hello=1"); assert_eq!(http_request.method(), "PUT"); diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/StructureGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/StructureGeneratorTest.kt index 7eeb072d53ab1168e0610aa84570224281981074..2782b656d6b03797f5fe96667fbd5dc01502a5c2 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/StructureGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/StructureGeneratorTest.kt @@ -6,76 +6,111 @@ package software.amazon.smithy.rust.codegen.generators import org.junit.jupiter.api.Test +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.traits.DocumentationTrait -import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.smithy.SymbolVisitor +import software.amazon.smithy.rust.codegen.smithy.canUseDefault import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.testutil.asSmithy import software.amazon.smithy.rust.testutil.shouldCompile import software.amazon.smithy.rust.testutil.testSymbolProvider class StructureGeneratorTest { - private val model: Model - private val struct: StructureShape - private val error: StructureShape - init { - val member1 = MemberShape.builder().id("com.test#MyStruct\$foo").target("smithy.api#String").build() - val member2 = MemberShape.builder().id("com.test#MyStruct\$bar").target("smithy.api#PrimitiveInteger").addTrait( - DocumentationTrait("This *is* documentation about the member.") - ).build() - val member3 = MemberShape.builder().id("com.test#MyStruct\$baz").target("smithy.api#Integer").build() - val member4 = MemberShape.builder().id("com.test#MyStruct\$ts").target("smithy.api#Timestamp").build() - - // struct 2 will be of type `Qux` under `MyStruct::quux` member - val struct2 = StructureShape.builder() - .id("com.test#Qux") - .build() - // structure member shape - note the capitalization of the member name (generated code should use the Kotlin class member name) - // val member4 = MemberShape.builder().id("com.test#MyStruct\$Quux").target(struct2).build() - val member5 = MemberShape.builder().id("com.test#MyStruct\$byteValue").target("smithy.api#Byte").build() - - struct = StructureShape.builder() - .id("com.test#MyStruct") - .addMember(member1) - .addMember(member2) - .addMember(member3) - .addMember(member4) - .addMember(member5) - .addTrait(DocumentationTrait("This *is* documentation about the shape.")) - .build() - - val messageMember = MemberShape.builder().id("com.test#MyError\$message").target("smithy.api#String").build() - - error = StructureShape.builder() - .id("com.test#MyError") - .addTrait(ErrorTrait("server")) - .addMember(messageMember).build() - model = Model.assembler() - .addShapes(struct, error, struct2, member1, member2, member3, messageMember) - .assemble() - .unwrap() - } + private val model = """ + namespace com.test + @documentation("this documents the shape") + structure MyStruct { + foo: String, + @documentation("This *is* documentation about the member.") + bar: PrimitiveInteger, + baz: Integer, + ts: Timestamp, + inner: Inner, + byteValue: Byte + } + + // Intentionally empty + structure Inner { + } + + @error("server") + structure MyError { + message: String + } + """.asSmithy() + private val struct = model.expectShape(ShapeId.from("com.test#MyStruct"), StructureShape::class.java) + private val inner = model.expectShape(ShapeId.from("com.test#Inner"), StructureShape::class.java) + private val error = model.expectShape(ShapeId.from("com.test#MyError"), StructureShape::class.java) @Test fun `generate basic structures`() { val provider: SymbolProvider = testSymbolProvider(model) - val writer = RustWriter("model.rs", "model") + val writer = RustWriter.forModule("model") + val innerGenerator = StructureGenerator(model, provider, writer, inner) val generator = StructureGenerator(model, provider, writer, struct) generator.render() + innerGenerator.render() writer.shouldCompile(""" let s: Option<MyStruct> = None; s.map(|i|println!("{:?}, {:?}", i.ts, i.byte_value)); - """.trimIndent()) + """.trimIndent() + ) + } + + @Test + fun `generate builders`() { + val provider: SymbolProvider = testSymbolProvider(model) + val writer = RustWriter.forModule("model") + val innerGenerator = StructureGenerator(model, provider, writer, inner) + val generator = StructureGenerator(model, provider, writer, struct) + generator.render() + innerGenerator.render() + writer.shouldCompile( + """ + let my_struct = MyStruct::builder().byte_value(4).foo("hello!").build(); + assert_eq!(my_struct.foo.unwrap(), "hello!"); + assert_eq!(my_struct.bar, 0); + """ + ) + } + + @Test + fun `generate fallible builders`() { + val baseProvider: SymbolProvider = testSymbolProvider(model) + val provider = + object : SymbolProvider { + override fun toSymbol(shape: Shape?): Symbol { + return baseProvider.toSymbol(shape).toBuilder().canUseDefault(false).build() + } + + override fun toMemberName(shape: MemberShape?): String { + return baseProvider.toMemberName(shape) + } + } + val writer = RustWriter.forModule("model") + val innerGenerator = StructureGenerator(model, provider, writer, inner) + val generator = StructureGenerator(model, provider, writer, struct) + generator.render() + innerGenerator.render() + writer.shouldCompile( + """ + let my_struct = MyStruct::builder().byte_value(4).foo("hello!").bar(0).build().expect("required field was not provided"); + assert_eq!(my_struct.foo.unwrap(), "hello!"); + assert_eq!(my_struct.bar, 0); + """ + ) + } @Test fun `generate error structures`() { val provider: SymbolProvider = SymbolVisitor(model, "test") - val writer = RustWriter("errors.rs", "errors") + val writer = RustWriter.forModule("error") val generator = StructureGenerator(model, provider, writer, error) generator.render() writer.shouldCompile() diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt index 686a571929893b212a2c803794a64bdc900355c4..0509d520e7c8501b68e96f3a1d66b2c3763b756b 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt @@ -39,7 +39,7 @@ class UnionGeneratorTest { .assemble() .unwrap() val provider: SymbolProvider = SymbolVisitor(model, "test") - val writer = RustWriter("model.rs", "model") + val writer = RustWriter.forModule("model") val generator = UnionGenerator(model, provider, writer, union) generator.render() val result = writer.toString() diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/lang/RustWriterTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/lang/RustWriterTest.kt index 92c7ab59b2defe2a0e8b155dd89cbb2e0405c174..a9321bb9ec123738065f30b914ddadda655fb95e 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/lang/RustWriterTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/lang/RustWriterTest.kt @@ -22,7 +22,7 @@ import software.amazon.smithy.rust.testutil.shouldParseAsRust class RustWriterTest { @Test fun `empty file`() { - val sut = RustWriter("empty.rs", "") + val sut = RustWriter.forModule("empty") sut.toString().shouldParseAsRust() sut.toString().shouldCompile() sut.toString().shouldMatchResource(javaClass, "empty.rs") @@ -30,7 +30,7 @@ class RustWriterTest { @Test fun `manually created struct`() { - val sut = RustWriter("lib.rs", "") + val sut = RustWriter.forModule("lib") val stringShape = StringShape.builder().id("test#Hello").build() val set = SetShape.builder() .id("foo.bar#Records") diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt index 32f6d29f3698f9c9f3d68099daea208e82fd8c29..91907000dd8bf31d81d0c8869ca7f0cbe613c094 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/testutil/Rust.kt @@ -21,7 +21,7 @@ fun String.shouldParseAsRust() { fun RustWriter.shouldCompile(main: String = "") { val deps = this.dependencies.map { RustDependency.fromSymbolDependency(it) } try { - this.toString().shouldCompile(deps.toSet(), main) + this.toString().shouldCompile(deps.toSet(), module = this.namespace.split("::")[1], main = main) } catch (e: CommandFailed) { // When the test fails, print the code for convenience println(this.toString()) @@ -29,7 +29,7 @@ fun RustWriter.shouldCompile(main: String = "") { } } -fun String.shouldCompile(deps: Set<RustDependency>, main: String = "") { +fun String.shouldCompile(deps: Set<RustDependency>, module: String? = null, main: String = "") { this.shouldParseAsRust() val tempDir = createTempDir() // TODO: unify this with CargoTomlGenerator @@ -46,13 +46,23 @@ fun String.shouldCompile(deps: Set<RustDependency>, main: String = "") { tempDir.resolve("Cargo.toml").writeText(cargoToml) tempDir.resolve("src").mkdirs() val mainRs = tempDir.resolve("src/main.rs") - mainRs.writeText(this) - if (!this.contains("fn main")) { - mainRs.appendText("\nfn main() { $main }\n") - } + val testModule = tempDir.resolve("src/$module.rs") + testModule.writeText(this) + testModule.appendText(""" + #[test] + fn test() { + $main + } + """.trimIndent()) + mainRs.appendText(""" + pub mod $module; + use crate::$module::*; + fn main() { + } + """.trimIndent()) "cargo check".runCommand(tempDir.toPath()) if (main != "") { - "cargo run".runCommand(tempDir.toPath()) + "cargo test".runCommand(tempDir.toPath()) } }