diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt new file mode 100644 index 0000000000000000000000000000000000000000..b3187599a1e1f9fdf810c6690b75273dfe3cb779 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGenerator.kt @@ -0,0 +1,777 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters + +import software.amazon.smithy.jmespath.ExpressionSerializer +import software.amazon.smithy.jmespath.JmespathExpression +import software.amazon.smithy.jmespath.ast.AndExpression +import software.amazon.smithy.jmespath.ast.BinaryExpression +import software.amazon.smithy.jmespath.ast.ComparatorExpression +import software.amazon.smithy.jmespath.ast.CurrentExpression +import software.amazon.smithy.jmespath.ast.ExpressionTypeExpression +import software.amazon.smithy.jmespath.ast.FieldExpression +import software.amazon.smithy.jmespath.ast.FilterProjectionExpression +import software.amazon.smithy.jmespath.ast.FlattenExpression +import software.amazon.smithy.jmespath.ast.FunctionExpression +import software.amazon.smithy.jmespath.ast.IndexExpression +import software.amazon.smithy.jmespath.ast.LiteralExpression +import software.amazon.smithy.jmespath.ast.MultiSelectHashExpression +import software.amazon.smithy.jmespath.ast.MultiSelectListExpression +import software.amazon.smithy.jmespath.ast.NotExpression +import software.amazon.smithy.jmespath.ast.ObjectProjectionExpression +import software.amazon.smithy.jmespath.ast.OrExpression +import software.amazon.smithy.jmespath.ast.ProjectionExpression +import software.amazon.smithy.jmespath.ast.SliceExpression +import software.amazon.smithy.jmespath.ast.Subexpression +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.SafeNamer +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.asRef +import software.amazon.smithy.rust.codegen.core.rustlang.plus +import software.amazon.smithy.rust.codegen.core.rustlang.render +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.orNull +import java.text.NumberFormat + +/** + * Contains information about the output of a visited [JmespathExpression]. + */ +data class GeneratedExpression( + /** The name of the identifier that this expression's evaluation is placed into */ + val identifier: String, + /** + * The Smithy shape for the output of the expression evaluation. + * + * This will be null for any computed values where there is no modeled shape + * for the result of the evaluated expression. For example, the multi-select + * list `['foo', 'baz']` is a list of string, but it isn't modeled anywhere, + * so there is no Smithy shape to represent it. + */ + val outputShape: Shape? = null, + /** + * The Rust type for the evaluated expression. + * + * For the most part, the code generator operates on output types rather than output shapes + * since there will always be an output type, whereas there will only sometimes be an output + * shape. Output shapes are only really used for handling enums and projections. + */ + val outputType: RustType, + /** Writable to output this expression's generated code. */ + val output: Writable, +) { + /** True if the type is a String, &str, or the shape is an enum shape. */ + fun isStringOrEnum(): Boolean = outputType.isString() || outputShape?.isEnumShape == true + + /** Dereferences this expression if it is a reference. */ + fun dereference(namer: SafeNamer): GeneratedExpression = + if (outputType is RustType.Reference) { + namer.safeName("_tmp").let { tmp -> + copy( + identifier = tmp, + outputType = outputType.member, + output = + output + + writable { + rust("let $tmp = *$identifier;") + }, + ) + } + } else { + this + } + + /** Converts this expression into a &str. */ + fun convertToStrRef(namer: SafeNamer): GeneratedExpression = + if (outputType is RustType.Reference && outputType.member is RustType.Reference) { + dereference(namer).convertToStrRef(namer) + } else if (!outputType.isString()) { + namer.safeName("_tmp").let { tmp -> + GeneratedExpression( + identifier = tmp, + outputType = RustType.String, + output = + output + + writable { + rust("let $tmp = $identifier.to_string();") + }, + ).convertToStrRef(namer) + } + } else if (!outputType.isStr()) { + namer.safeName("_tmp").let { tmp -> + GeneratedExpression( + identifier = tmp, + outputType = RustType.Reference(null, RustType.Opaque("str")), + output = + output + + writable { + rust("let $tmp = $identifier.as_str();") + }, + ).convertToStrRef(namer) + } + } else { + this + } + + /** Converts a number expression into a specific number type */ + fun convertToNumberPrimitive( + namer: SafeNamer, + desiredPrimitive: RustType, + ): GeneratedExpression { + check(outputType.isNumber() && desiredPrimitive.isNumber()) { + "this function only works on number types" + } + + return when { + desiredPrimitive is RustType.Reference -> convertToNumberPrimitive(namer, desiredPrimitive.member) + outputType is RustType.Reference -> dereference(namer).convertToNumberPrimitive(namer, desiredPrimitive) + outputType != desiredPrimitive -> + namer.safeName("_tmp").let { tmp -> + GeneratedExpression( + identifier = tmp, + outputType = desiredPrimitive, + output = + output + + writable { + rust("let $tmp = $identifier as ${desiredPrimitive.render()};") + }, + ) + } + else -> this + } + } +} + +/** + * Indicates a feature that's part of the JmesPath spec, but that we explicitly decided + * not to support in smithy-rs due to the complexity of code generating it for Rust. + */ +data class UnsupportedJmesPathException(private val msg: String) : RuntimeException(msg) + +/** Code can't be generated for the combination of the Smithy shape and the JmesPath expression. */ +data class InvalidJmesPathTraversalException(private val msg: String) : RuntimeException(msg) + +/** This indicates a bug in the code generator itself that should be fixed. */ +data class JmesPathTraversalCodegenBugException(private val msg: String) : RuntimeException(msg) + +/** + * Generates code from a JmesPath expression to traverse generated Smithy shapes. + * + * This generator implements a subset of the JmesPath spec since the full spec has more features + * than are needed for real-world Smithy waiters, and some of those features are very complex + * to code generate for Rust. + * + * Specifically, the following Jmespath features are supported: + * - Fields + * - Sub-expressions + * - Comparisons + * - Filter projections + * - Object projections + * - Multi-select lists (but only when every item in the list is the exact same type) + * - And/or/not boolean operations + * - Functions `contains` and `length`. The `keys` function may be supported in the future. + */ +class RustJmespathShapeTraversalGenerator( + codegenContext: ClientCodegenContext, +) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val safeNamer = SafeNamer() + + fun generate( + expr: JmespathExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + try { + val result = + when (expr) { + is ComparatorExpression -> generateComparator(expr, inputName, shape) + is FunctionExpression -> generateFunction(expr, inputName, shape) + is FieldExpression -> generateField(expr, inputName, shape) + is LiteralExpression -> generateLiteral(expr) + is MultiSelectListExpression -> generateMultiSelectList(expr, inputName, shape) + is AndExpression -> generateAnd(expr, inputName, shape) + is OrExpression -> generateOr(expr, inputName, shape) + is NotExpression -> generateNot(expr, inputName, shape) + is ObjectProjectionExpression -> generateObjectProjection(expr, inputName, shape) + is FilterProjectionExpression -> generateFilterProjection(expr, inputName, shape) + is ProjectionExpression -> generateProjection(expr, inputName, shape) + is Subexpression -> generateSubexpression(expr, inputName, shape) + is CurrentExpression -> throw JmesPathTraversalCodegenBugException("current expression must be handled in each expression type that can have one") + is ExpressionTypeExpression -> throw UnsupportedJmesPathException("Expression type expressions are not supported by smithy-rs") + is IndexExpression -> throw UnsupportedJmesPathException("Index expressions are not supported by smithy-rs") + is MultiSelectHashExpression -> throw UnsupportedJmesPathException("Multi-select hash expressions are not supported by smithy-rs") + is SliceExpression -> throw UnsupportedJmesPathException("Slice expressions are not supported by smithy-rs") + else -> throw UnsupportedJmesPathException("${expr.javaClass.name} expression type not supported by smithy-rs") + } + return result.copy( + output = + writable { + result.output(this) + if (debugMode) { + rust("// ${result.identifier} = ${ExpressionSerializer().serialize(expr)}") + } + }, + ) + } catch (ex: UnsupportedJmesPathException) { + throw ex.copy(msg = "${ex.message}\nExpression: ${ExpressionSerializer().serialize(expr)}") + } catch (ex: InvalidJmesPathTraversalException) { + throw ex.copy(msg = "${ex.message}\nExpression: ${ExpressionSerializer().serialize(expr)}") + } catch (ex: JmesPathTraversalCodegenBugException) { + throw ex.copy(msg = "${ex.message}\nExpression: ${ExpressionSerializer().serialize(expr)}") + } + } + + private fun generateComparator( + expr: ComparatorExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val left = generate(expr.left, inputName, shape) + val right = generate(expr.right, inputName, shape) + return generateCompare(left, right, expr.comparator.toString()) + } + + private fun generateCompare( + left: GeneratedExpression, + right: GeneratedExpression, + op: String, + ): GeneratedExpression = + if (left.outputType.isDoubleReference()) { + generateCompare(left.dereference(safeNamer), right, op) + } else if (right.outputType.isDoubleReference()) { + generateCompare(left, right.dereference(safeNamer), op) + } else { + safeNamer.safeName("_cmp").let { ident -> + return GeneratedExpression( + identifier = ident, + outputType = RustType.Bool, + output = + if (left.isStringOrEnum() && right.isStringOrEnum()) { + writable { + val leftStr = left.convertToStrRef(safeNamer).also { it.output(this) } + val rightStr = right.convertToStrRef(safeNamer).also { it.output(this) } + rust("let $ident = ${leftStr.identifier} $op ${rightStr.identifier};") + } + } else if (left.outputType.isNumber() && right.outputType.isNumber()) { + writable { + val leftPrim = + left.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) } + val rightPrim = + right.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) } + rust("let $ident = ${leftPrim.identifier} $op ${rightPrim.identifier};") + } + } else if (left.outputType.isBool() && right.outputType.isBool()) { + left.output + right.output + + writable { + rust("let $ident = ${left.identifier} $op ${right.identifier};") + } + } else { + throw UnsupportedJmesPathException("Comparison of ${left.outputType.render()} with ${right.outputType.render()} is not supported by smithy-rs") + }, + ) + } + } + + private fun generateFunction( + expr: FunctionExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val ident = safeNamer.safeName("_ret") + return when (expr.name) { + "length" -> { + if (expr.arguments.size != 1) { + throw InvalidJmesPathTraversalException("Length function takes exactly one argument") + } + val arg = generate(expr.arguments[0], inputName, shape) + if (!arg.outputType.isCollection() && !arg.outputType.isString()) { + throw InvalidJmesPathTraversalException("Argument to `length` function must be a collection or string type") + } + GeneratedExpression( + identifier = ident, + outputType = RustType.Integer(64), + output = + writable { + arg.output(this) + rust("let $ident = ${arg.identifier}.len() as i64;") + }, + ) + } + + "contains" -> { + if (expr.arguments.size != 2) { + throw InvalidJmesPathTraversalException("Contains function takes exactly two arguments") + } + val left = generate(expr.arguments[0], inputName, shape) + if (!left.outputType.isCollection() && !left.outputType.isString()) { + throw InvalidJmesPathTraversalException("First argument to `contains` function must be a collection or string type") + } + if (expr.arguments[1].isLiteralNull()) { + throw UnsupportedJmesPathException("Checking for null with `contains` is not supported in smithy-rs") + } + val right = generate(expr.arguments[1], inputName, shape) + if (!right.outputType.isNumber() && !right.outputType.isString() && right.outputShape?.isEnumShape != true) { + throw UnsupportedJmesPathException("Checking for anything other than numbers, strings, or enums in the `contains` function is not supported in smithy-rs") + } + if (left.outputType.isString()) { + return GeneratedExpression( + identifier = ident, + outputType = RustType.Bool, + output = + left.output + right.output + + writable { + if (right.outputType.isString()) { + rust("let $ident = ${left.identifier}.contains(${right.identifier});") + } else { + val tmp = safeNamer.safeName("_tmp") + rust("let $tmp = ${right.identifier}.to_string();") + rust("let $ident = ${left.identifier}.contains(&$tmp);") + } + }, + ) + } else { + return GeneratedExpression( + identifier = ident, + outputType = RustType.Bool, + output = + left.output + right.output + + writable { + withBlockTemplate("let $ident = ${left.identifier}.iter().any(|_v| {", "});") { + val compare = + generateCompare( + GeneratedExpression( + identifier = "_v", + outputShape = + (left.outputShape as? CollectionShape)?.member?.target?.let { + model.expectShape( + it, + ) + }, + outputType = + RustType.Reference( + lifetime = null, + member = left.outputType.collectionValue(), + ), + output = writable {}, + ), + // Clear the output since we already wrote the right and don't want to duplicate it + right.copy(output = writable {}), + "==", + ).also { it.output(this) } + + rust(compare.identifier) + } + }, + ) + } + } + + else -> throw UnsupportedJmesPathException("The `${expr.name}` function is not supported by smithy-rs") + } + } + + private fun generateField( + expr: FieldExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + if (shape is StructureShape) { + val member = + shape.getMember(expr.name).orNull() + ?: throw InvalidJmesPathTraversalException("Member `${expr.name}` doesn't exist on ${shape.id}") + val memberSym = symbolProvider.toSymbol(member) + + val target = model.expectShape(member.target) + val targetSym = symbolProvider.toSymbol(target) + + val ident = safeNamer.safeName("_fld") + return GeneratedExpression( + identifier = ident, + outputShape = target, + outputType = targetSym.rustType().asRef(), + output = + writable { + rust( + if (memberSym.isOptional()) { + "let $ident = $inputName.${memberSym.name}.as_ref()?;" + } else { + "let $ident = &$inputName.${memberSym.name};" + }, + ) + }, + ) + } else { + throw InvalidJmesPathTraversalException("Cannot look up fields in non-struct shapes") + } + } + + private fun generateLiteral(expr: LiteralExpression): GeneratedExpression { + val outputType = + when (expr.value) { + is Boolean -> RustType.Reference(lifetime = null, member = RustType.Bool) + is Double -> RustType.Reference(lifetime = null, member = RustType.Float(64)) + is String -> RustType.Reference(lifetime = null, member = RustType.Opaque("str")) + null -> throw UnsupportedJmesPathException("Literal nulls are not supported by smithy-rs") + else -> throw UnsupportedJmesPathException("Literal expression '${ExpressionSerializer().serialize(expr)}' is not supported by smithy-rs") + } + + fun fmtFloating(floating: Number) = + NumberFormat.getInstance().apply { minimumFractionDigits = 1 }.format(floating) + + return safeNamer.safeName("_lit").uppercase().let { ident -> + GeneratedExpression( + identifier = ident, + outputType = outputType, + output = + writable { + when (val value = expr.value) { + is Boolean -> rust("const $ident: &bool = &$value;") + is Double -> { + rust("const $ident: #T = &${fmtFloating(value)};", outputType) + } + + is String -> rust("const $ident: &str = ${value.dq()};") + else -> throw RuntimeException("unreachable") + } + }, + ) + } + } + + private fun generateMultiSelectList( + expr: MultiSelectListExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val expressions = + expr.expressions.map { subexpr -> + generate(subexpr, inputName, shape) + } + // If we wanted to support mixed-types, we would need to use tuples, add tuple support to RustType, + // and update supported functions such as `contains` to operate on tuples. + for (pair in expressions.map { it.outputType }.windowed(2)) { + if (pair[0] != pair[1]) { + throw UnsupportedJmesPathException("Mixed-type multi-select lists are not supported by smithy-rs") + } + } + + return safeNamer.safeName("_msl").let { ident -> + GeneratedExpression( + identifier = ident, + outputType = RustType.Vec(expressions[0].outputType), + output = + writable { + expressions.forEach { it.output(this) } + rust("let $ident = vec![${expressions.map { it.identifier }.joinToString(", ")}];") + }, + ) + } + } + + private fun generateAnd( + expr: AndExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression = generateBooleanOp(expr, "&&", inputName, shape) + + private fun generateOr( + expr: OrExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression = generateBooleanOp(expr, "||", inputName, shape) + + private fun generateBooleanOp( + expr: BinaryExpression, + op: String, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val left = generate(expr.left, inputName, shape) + val right = generate(expr.right, inputName, shape) + if (!left.outputType.isBool() || !right.outputType.isBool()) { + throw UnsupportedJmesPathException("Applying the `$op` operation doesn't support non-boolean types in smithy-rs") + } + + return safeNamer.safeName("_bo").let { ident -> + GeneratedExpression( + identifier = ident, + outputType = RustType.Bool, + output = + writable { + val leftBool = left.dereference(safeNamer).also { it.output(this) } + val rightBool = right.dereference(safeNamer).also { it.output(this) } + rust("let $ident = ${leftBool.identifier} $op ${rightBool.identifier};") + }, + ) + } + } + + private fun generateNot( + expr: NotExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val inner = generate(expr.expression, inputName, shape) + if (!inner.outputType.isBool()) { + throw UnsupportedJmesPathException("Negation of a non-boolean type is not supported by smithy-rs") + } + + return safeNamer.safeName("_not").let { ident -> + GeneratedExpression( + identifier = ident, + outputType = RustType.Bool, + output = + inner.output + + writable { + rust("let $ident = !${inner.identifier};") + }, + ) + } + } + + private fun generateProjection( + expr: ProjectionExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val maybeFlatten = expr.left + if (maybeFlatten is SliceExpression) { + throw UnsupportedJmesPathException("Slice expressions are not supported by smithy-rs") + } + if (maybeFlatten !is FlattenExpression) { + throw UnsupportedJmesPathException("Only projection expressions with flattens are supported by smithy-rs") + } + val left = generate(maybeFlatten.expression, inputName, shape) + val leftTarget = + when (val outputShape = left.outputShape) { + is ListShape -> model.expectShape(outputShape.member.target) + else -> throw InvalidJmesPathTraversalException("Left side of the flatten projection MUST resolve to a list or set shape") + } + val leftTargetSym = symbolProvider.toSymbol(leftTarget) + + // Short-circuit in the case where the projection is unnecessary + if (left.outputType.isCollection() && expr.right is CurrentExpression) { + return left + } + + val right = generate(expr.right, "v", leftTarget) + val projectionType = RustType.Vec(right.outputType.asRef()) + + return safeNamer.safeName("_prj").let { ident -> + GeneratedExpression( + identifier = ident, + outputShape = right.outputShape, + outputType = projectionType, + output = + left.output + + writable { + rustBlock("let $ident = ${left.identifier}.iter().flat_map(") { + rustBlockTemplate( + "fn map(v: &#{Left}) -> #{Option}<#{Right}>", + *preludeScope, + "Left" to leftTargetSym, + "Right" to right.outputType, + ) { + right.output(this) + rustTemplate("#{Some}(${right.identifier})", *preludeScope) + } + rust("map") + } + rustTemplate(").collect::<#{Vec}<_>>();", *preludeScope) + }, + ) + } + } + + private fun generateFilterProjection( + expr: FilterProjectionExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val left = generate(expr.left, inputName, shape) + if (!left.outputType.isList() && !left.outputType.isSet()) { + throw UnsupportedJmesPathException("Filter projections can only be done on lists or sets in smithy-rs") + } + + val leftTarget = model.expectShape((left.outputShape as ListShape).member.target) + val leftTargetSym = symbolProvider.toSymbol(leftTarget) + + val right = + if (expr.right is CurrentExpression) { + left.copy( + outputType = left.outputType.collectionValue().asRef(), + output = writable {}, + ) + } else { + generate(expr.right, "_v", leftTarget) + } + + val comparison = generate(expr.comparison, "_v", leftTarget) + if (!comparison.outputType.isBool()) { + throw InvalidJmesPathTraversalException("The filter expression comparison must result in a boolean") + } + + return safeNamer.safeName("_fprj").let { ident -> + GeneratedExpression( + identifier = ident, + outputShape = null, + outputType = RustType.Vec(right.outputType), + output = + left.output + + writable { + rust("let $ident = ${left.identifier}.iter()") + withBlock(".filter({", "})") { + rustBlockTemplate("fn filter(_v: &#{Arg}) -> #{Option}", "Arg" to leftTargetSym, *preludeScope) { + val toBool = comparison.dereference(safeNamer).also { it.output(this) } + rustTemplate("#{Some}(${toBool.identifier})", *preludeScope) + } + rust("|v| filter(v).unwrap_or_default()") + } + if (expr.right !is CurrentExpression) { + withBlock(".flat_map({", "})") { + rustBlockTemplate( + "fn map(_v: &#{Left}) -> #{Option}<#{Right}>", + *preludeScope, + "Left" to leftTargetSym, + "Right" to right.outputType, + ) { + right.output(this) + rustTemplate("#{Some}(${right.identifier})", *preludeScope) + } + rust("map") + } + } + rustTemplate(".collect::<#{Vec}<_>>();", *preludeScope) + }, + ) + } + } + + private fun generateObjectProjection( + expr: ObjectProjectionExpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + if (expr.left is CurrentExpression) { + throw UnsupportedJmesPathException("Object projection cannot be done on computed maps in smithy-rs") + } + val left = generate(expr.left, inputName, shape) + if (!left.outputType.isMap()) { + throw UnsupportedJmesPathException("Object projection is only supported on map types in smithy-rs") + } + if (left.outputShape == null) { + throw UnsupportedJmesPathException("Object projection cannot be done on computed maps in smithy-rs") + } + + val leftTarget = model.expectShape((left.outputShape as MapShape).value.target) + val leftTargetSym = symbolProvider.toSymbol(leftTarget) + + val right = + if (expr.right is CurrentExpression) { + left.copy( + outputType = left.outputType.collectionValue().asRef(), + output = writable {}, + ) + } else { + generate(expr.right, "_v", leftTarget) + } + + val ident = safeNamer.safeName("_oprj") + return GeneratedExpression( + identifier = ident, + outputShape = null, + outputType = RustType.Vec(right.outputType), + output = + left.output + + writable { + if (expr.right is CurrentExpression) { + rustTemplate("let $ident = ${left.identifier}.values().collect::<#{Vec}<_>>();", *preludeScope) + } else { + rustBlock("let $ident = ${left.identifier}.values().flat_map(") { + rustBlockTemplate( + "fn map(_v: &#{Left}) -> #{Option}<#{Right}>", + *preludeScope, + "Left" to leftTargetSym, + "Right" to right.outputType, + ) { + right.output(this) + rustTemplate("#{Some}(${right.identifier})", *preludeScope) + } + rust("map") + } + rustTemplate(").collect::<#{Vec}<_>>();", *preludeScope) + } + }, + ) + } + + private fun generateSubexpression( + expr: Subexpression, + inputName: String, + shape: Shape, + ): GeneratedExpression { + val left = generate(expr.left, inputName, shape) + val right = generate(expr.right, left.identifier, left.outputShape!!) + return GeneratedExpression( + identifier = right.identifier, + outputShape = right.outputShape, + outputType = right.outputType, + output = left.output + right.output, + ) + } +} + +private fun RustType.dereference(): RustType = + if (this is RustType.Reference) { + this.member.dereference() + } else { + this + } + +private fun RustType.isList(): Boolean = this.dereference() is RustType.Vec + +private fun RustType.isSet(): Boolean = this.dereference() is RustType.HashSet + +private fun RustType.isMap(): Boolean = this.dereference() is RustType.HashMap + +private fun RustType.isCollection(): Boolean = isList() || isSet() || isMap() + +private fun RustType.isString(): Boolean = this.dereference().let { it is RustType.String || it.isStr() } + +private fun RustType.isStr(): Boolean = this.dereference().let { it is RustType.Opaque && it.name == "str" } + +private fun RustType.isNumber(): Boolean = this.dereference().let { it is RustType.Integer || it is RustType.Float } + +private fun RustType.isBool(): Boolean = this.dereference() is RustType.Bool + +private fun RustType.isDoubleReference(): Boolean = this is RustType.Reference && this.member is RustType.Reference + +private fun RustType.collectionValue(): RustType = + when (this) { + is RustType.Reference -> member.collectionValue() + is RustType.Vec -> member + is RustType.HashSet -> member + is RustType.HashMap -> member + else -> throw RuntimeException("expected collection type") + } + +private fun JmespathExpression.isLiteralNull(): Boolean = this == LiteralExpression.NULL diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..0bc4385b312fcd23d096686d5cbc39f2de15e1f8 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/waiters/RustJmespathShapeTraversalGeneratorTest.kt @@ -0,0 +1,591 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.fail +import software.amazon.smithy.jmespath.JmespathExpression +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.core.util.outputShape + +class RustJmespathShapeTraversalGeneratorTest { + private class TestCase( + private val codegenContext: ClientCodegenContext, + private val rustCrate: RustCrate, + ) { + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + + private val outputShape = + model.lookup("test#TestOperation") + .outputShape(model) + val output = symbolProvider.toSymbol(outputShape) + val entityPrimitives = symbolProvider.toSymbol(model.lookup("test#EntityPrimitives")) + val entityLists = symbolProvider.toSymbol(model.lookup("test#EntityLists")) + val entityMaps = symbolProvider.toSymbol(model.lookup("test#EntityMaps")) + val enum = symbolProvider.toSymbol(model.lookup("test#Enum")) + val struct = symbolProvider.toSymbol(model.lookup("test#Struct")) + + val testDataFn1: RuntimeType + + init { + testDataFn1 = + RuntimeType.forInlineFun("test_data_1", ClientRustModule.root) { + rustTemplate( + """ + ##[allow(dead_code)] + fn test_data_1() -> #{Output} { + let primitives = #{EntityPrimitives}::builder() + .required_boolean(true) + .required_string("required-test") + .boolean(true) + .string("test") + .byte(1) + .short(2) + .integer(4) + .long(8) + .float(4.0) + .double(8.0) + .r##enum(#{Enum}::One) + .int_enum(1) + .build() + .unwrap(); + #{Output}::builder() + .primitives(primitives.clone()) + .lists(#{EntityLists}::builder() + .shorts(1).shorts(2) + .integers(3).integers(4) + .longs(5).longs(6) + .floats(7.0).floats(8.0) + .doubles(9.0).doubles(10.0) + .strings("one").strings("two") + .enums(#{Enum}::Two) + .int_enums(2) + .structs(#{Struct}::builder() + .primitives(primitives.clone()) + .build()) + .build()) + .maps(#{EntityMaps}::builder() + .strings("foo", "foo_oo") + .strings("bar", "bar_ar") + .booleans("foo", true) + .booleans("bar", false) + .structs("foo", #{Struct}::builder().integer(5).build()) + .build()) + .build() + } + """, + "Output" to output, + "EntityPrimitives" to entityPrimitives, + "EntityLists" to entityLists, + "EntityMaps" to entityMaps, + "Enum" to enum, + "Struct" to struct, + ) + } + } + + fun testCase( + testName: String, + expression: String, + testData: RuntimeType, + assertions: RustWriter.() -> Unit, + ) { + val generator = RustJmespathShapeTraversalGenerator(codegenContext) + val parsed = JmespathExpression.parse(expression) + val generated = generator.generate(parsed, "_output", outputShape) + rustCrate.unitTest(testName) { + rust("// jmespath: $expression") + rust("// jmespath parsed: $parsed") + rustBlockTemplate( + "fn inner(_output: &#{Arg}) -> #{Option}<#{Ret}>", + "Arg" to output, + "Ret" to generated.outputType, + *preludeScope, + ) { + generated.output(this) + rustTemplate("#{Some}(${generated.identifier})", *preludeScope) + } + rustTemplate("let output = #{test_data}();", "test_data" to testData) + rust("""println!("test data: {output:##?}");""") + rust("""println!("jmespath: {}", ${expression.dq()});""") + rust("let result = inner(&output);") + rust("""println!("result: {result:##?}");""") + rust("let result = result.unwrap();") + assertions() + // Unused variable suppression + rust("let _ = result;") + } + } + + fun invalid( + expression: String, + contains: String, + ) { + try { + val generator = RustJmespathShapeTraversalGenerator(codegenContext) + val parsed = JmespathExpression.parse(expression) + generator.generate(parsed, "_output", outputShape).output(RustWriter.forModule("unsupported")) + fail("expression '$expression' should have thrown InvalidJmesPathTraversalException") + } catch (ex: InvalidJmesPathTraversalException) { + ex.message shouldContain contains + } + } + + fun unsupported( + expression: String, + contains: String, + ) { + try { + val generator = RustJmespathShapeTraversalGenerator(codegenContext) + val parsed = JmespathExpression.parse(expression) + generator.generate(parsed, "_output", outputShape).output(RustWriter.forModule("unsupported")) + fail("expression '$expression' should have thrown UnsupportedJmesPathException") + } catch (ex: UnsupportedJmesPathException) { + ex.message shouldContain contains + } + } + } + + private fun integrationTest(testCases: TestCase.() -> Unit) { + clientIntegrationTest(testModel()) { codegenContext, rustCrate -> + TestCase(codegenContext, rustCrate).testCases() + } + } + + private fun simple(assertion: String): RustWriter.() -> Unit = + { + rust(assertion) + } + + private val expectFalse = simple("assert_eq!(false, result);") + private val expectTrue = simple("assert!(result);") + private val itCompiles = simple("") + + @Test + fun all() = + integrationTest { + fieldExpressions() + subExpressions() + flattenExpressions() + literalTypes() + functions() + comparisons() + objectProjections() + filterProjections() + booleanOperations() + multiSelectLists() + complexCombinationsOfFeatures() + + unsupported("&('foo')", "Expression type expressions") + unsupported("lists.integers[0]", "Index expressions") + unsupported("""{"foo": primitives, "bar": integer}""", "Multi-select hash expressions") + unsupported("lists.integers[0:2]", "Slice expressions") + } + + private fun TestCase.fieldExpressions() { + fun test( + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_field_$expression", expression, testDataFn1, assertions) + + test("primitives") { + rust("assert!(std::ptr::eq(output.primitives.as_ref().unwrap(), result));") + rust("""assert_eq!("test", result.string.as_deref().unwrap());""") + } + test("lists") { + rust("assert!(std::ptr::eq(output.lists.as_ref().unwrap(), result));") + } + test("maps") { + rust("assert!(std::ptr::eq(output.maps.as_ref().unwrap(), result));") + } + + invalid("doesNotExist", "Member `doesNotExist` doesn't exist") + } + + private fun TestCase.subExpressions() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_subexpression_$name", expression, testDataFn1, assertions) + + test("boolean", "primitives.boolean", expectTrue) + test("string", "primitives.string", simple("assert_eq!(\"test\", result);")) + test("byte", "primitives.byte", simple("assert_eq!(1i8, *result);")) + test("short", "primitives.short", simple("assert_eq!(2i16, *result);")) + test("integer", "primitives.integer", simple("assert_eq!(4i32, *result);")) + test("long", "primitives.long", simple("assert_eq!(8i64, *result);")) + test("float", "primitives.float", simple("assert_eq!(4f32, *result);")) + test("double", "primitives.double", simple("assert_eq!(8f64, *result);")) + test("enum", "primitives.enum") { + rust("assert_eq!(#T::One, *result);", enum) + } + test("int_enum", "primitives.intEnum", simple("assert_eq!(1, *result);")) + + invalid("primitives.integer.foo", "Cannot look up fields in non-struct shapes") + + test("required_boolean", "primitives.requiredBoolean", expectTrue) + test("required_string", "primitives.requiredString", simple("assert_eq!(\"required-test\", result);")) + } + + private fun TestCase.flattenExpressions() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_flatten_$name", expression, testDataFn1, assertions) + + test("shortcircuit", "lists.structs[]") { + rust("assert!(std::ptr::eq(output.lists.as_ref().unwrap().structs.as_ref().unwrap(), result));") + } + test("no_shortcircuit", "lists.structs[].primitives.string") { + rust("assert_eq!(1, result.len());") + rust("assert_eq!(\"test\", result[0]);") + } + + invalid("primitives.integer[]", "Left side of the flatten") + } + + private fun TestCase.literalTypes() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_literal_$name", expression, testDataFn1, assertions) + + test("bool", "`true`", expectTrue) + test("int", "`0`", simple("assert_eq!(0f64, *result);")) + test("float", "`1.5`", simple("assert_eq!(1.5f64, *result);")) + test("string", "`\"foo\"`", simple("assert_eq!(\"foo\", result);")) + + unsupported("`null`", "Literal nulls") + unsupported("`{}`", "Literal expression '`{}`'") + unsupported("`[]`", "Literal expression '`[]`'") + } + + private fun TestCase.functions() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_fn_$name", expression, testDataFn1, assertions) + + test("list_length", "length(lists.structs[])", simple("assert_eq!(1, result);")) + test("string_length", "length(primitives.string)", simple("assert_eq!(4, result);")) + + test("string_contains_false", "contains(primitives.string, 'foo')", expectFalse) + test("string_contains_true", "contains(primitives.string, 'st')", expectTrue) + + test("strings_contains_false", "contains(lists.strings, 'foo')", expectFalse) + test("strings_contains_true", "contains(lists.strings, 'two')", expectTrue) + + test("i16s_contains_false", "contains(lists.shorts, `0`)", expectFalse) + test("i16s_contains_true", "contains(lists.shorts, `1`)", expectTrue) + + test("i32s_contains_false", "contains(lists.integers, `0`)", expectFalse) + test("i32s_contains_true", "contains(lists.integers, `3`)", expectTrue) + + test("i64s_contains_false", "contains(lists.longs, `0`)", expectFalse) + test("i64s_contains_true", "contains(lists.longs, `5`)", expectTrue) + + test("f32s_contains_false", "contains(lists.floats, `0`)", expectFalse) + test("f32s_contains_true", "contains(lists.floats, `7.0`)", expectTrue) + + test("f64s_contains_false", "contains(lists.doubles, `0`)", expectFalse) + test("f64s_contains_true", "contains(lists.doubles, `9.0`)", expectTrue) + + test("enums_contains_false", "contains(lists.enums, 'one')", expectFalse) + test("enums_contains_true", "contains(lists.enums, 'two')", expectTrue) + + test("intenums_contains_false", "contains(lists.intEnums, `1`)", expectFalse) + test("intenums_contains_true", "contains(lists.intEnums, `2`)", expectTrue) + + test("stringlit_contains_stringlit_false", "contains('foo', 'o0')", expectFalse) + test("stringlit_contains_stringlit_true", "contains('foo', 'oo')", expectTrue) + + test("strings_contains_string", "contains(lists.strings, primitives.string)", expectFalse) + test("i32s_contains_i32", "contains(lists.integers, primitives.integer)", expectTrue) + test("i32s_contains_i16", "contains(lists.integers, primitives.short)", expectFalse) + test("f32s_contains_f32", "contains(lists.floats, primitives.float)", expectFalse) + + invalid("length()", "Length function takes exactly one argument") + invalid("length(primitives.integer)", "Argument to `length` function") + invalid("contains('foo')", "Contains function takes exactly two arguments") + invalid("contains(primitives.integer, 'foo')", "First argument to `contains`") + unsupported("contains(lists.structs, `null`)", "Checking for null with `contains`") + unsupported("contains(lists.structs, lists)", "Checking for anything other than") + unsupported("abs(`-1`)", "The `abs` function is not supported") + unsupported("contains(lists.floats, primitives.string)", "Comparison of &f32 with &::std::string::String") + } + + private fun TestCase.comparisons() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_compare_$name", expression, testDataFn1, assertions) + + test("eq_boollit_w_boollit", "`true` == `true`", expectTrue) + test("neq_boollit_w_boollit", "`true` != `true`", expectFalse) + test("boollit_w_boollit", "`true` != `true`", expectFalse) + test("bool_w_boollit", "primitives.boolean != `true`", expectFalse) + test("bool_w_bool", "primitives.boolean == primitives.boolean", expectTrue) + test("eq_integerlit_w_integerlit", "`0` == `0`", expectTrue) + test("neq_integerlit_w_integerlit", "`0` != `0`", expectFalse) + test("lt_integerlit_w_integerlit_false", "`0` < `0`", expectFalse) + test("lt_integerlit_w_integerlit_true", "`0` < `1`", expectTrue) + test("integer_w_integerlit", "primitives.integer != `0`", expectTrue) + test("integer_w_integer", "primitives.integer == primitives.integer", expectTrue) + test("float_w_integer_true", "primitives.float == primitives.integer", expectTrue) + test("integer_w_float_true", "primitives.integer == primitives.float", expectTrue) + test("float_w_integer_false", "primitives.float != primitives.integer", expectFalse) + test("integer_w_float_false", "primitives.integer != primitives.float", expectFalse) + test("eq_stringlit_w_stringlit", "'foo' == 'foo'", expectTrue) + test("neq_stringlit_w_stringlit", "'bar' != 'foo'", expectTrue) + test("string_w_stringlit_false", "primitives.string == 'foo'", expectFalse) + test("string_w_stringlit_true", "primitives.string == 'test'", expectTrue) + test("string_w_string", "primitives.string == primitives.string", expectTrue) + test("enum_w_stringlit_false", "primitives.enum == 'one'", expectTrue) + test("enum_w_stringlit_true", "primitives.enum == 'two'", expectFalse) + test("enum_w_string", "primitives.enum == primitives.string", expectFalse) + test("fn_w_number", "length(lists.structs[]) > `0`", expectTrue) + + unsupported("'foo' == `1`", "Comparison of &str with &f64") + unsupported("primitives.string == primitives.integer", "Comparison of &::std::string::String with &i32") + } + + private fun TestCase.objectProjections() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_obj_projection_$name", expression, testDataFn1, assertions) + + test("traverse_obj_projection_simple", "maps.booleans.*") { + rust("assert_eq!(2, result.len());") + // Order is non-deterministic because we're getting the values of a hash map + rust("assert_eq!(1, result.iter().filter(|&&&b| b == true).count());") + rust("assert_eq!(1, result.iter().filter(|&&&b| b == false).count());") + } + test("traverse_obj_projection_continued", "maps.structs.*.integer") { + rust("assert_eq!(1, result.len());") + rust("assert_eq!(5, **result.get(0).unwrap());") + } + test("traverse_obj_projection_complex", "length(maps.structs.*.strings) == `0`", expectTrue) + + unsupported("primitives.integer.*", "Object projection is only supported on map types") + unsupported("lists.structs[?`true`].*", "Object projection cannot be done on computed maps") + } + + private fun TestCase.filterProjections() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_filter_projection_$name", expression, testDataFn1, assertions) + + test("boollit", "lists.structs[?`true`]") { + rust("assert_eq!(1, result.len());") + } + test("intcmp", "lists.structs[?primitives.integer > `0`]") { + rust("assert_eq!(1, result.len());") + } + test("boollit_continued_empty", "lists.structs[?`true`].integer") { + rust("assert_eq!(0, result.len());") + } + test("boollit_continued", "lists.structs[?`true`].primitives.integer") { + rust("assert_eq!(1, result.len());") + } + test("intcmp_continued", "lists.structs[?primitives.integer > `0`].primitives.integer") { + rust("assert_eq!(1, result.len());") + rust("assert_eq!(4, **result.get(0).unwrap());") + } + test("intcmp_continued_filtered", "lists.structs[?primitives.integer == `0`].primitives.integer") { + rust("assert_eq!(0, result.len());") + } + + unsupported("primitives.integer[?`true`]", "Filter projections can only be done on lists") + invalid("lists.structs[?`5`]", "The filter expression comparison must result in a bool") + } + + private fun TestCase.booleanOperations() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_boolean_ops_$name", expression, testDataFn1, assertions) + + test("lit_not", "!`true`", expectFalse) + test("bool_not", "!(primitives.boolean)", expectFalse) + test("lit_and_lit", "`true` && `false`", expectFalse) + test("lit_or_lit", "`true` || `false`", expectTrue) + test("bool_and_lit", "primitives.boolean && `true`", expectTrue) + test("bool_or_lit", "primitives.boolean || `false`", expectTrue) + test("bool_and_bool", "primitives.boolean && primitives.boolean", expectTrue) + test("bool_or_bool", "primitives.boolean || primitives.boolean", expectTrue) + test("paren_expressions", "(`true` || `false`) && `true`", expectTrue) + + unsupported("`5` || `true`", "Applying the `||` operation doesn't support non-bool") + unsupported("`5` && `true`", "Applying the `&&` operation doesn't support non-bool") + unsupported("!`5`", "Negation of a non-boolean type") + } + + private fun TestCase.multiSelectLists() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_multiselectlists_$name", expression, testDataFn1, assertions) + + test("intlist_contains", "contains([`1`, `2`, `3`], `1`)", expectTrue) + test("stringlist_contains", "contains(['foo', 'bar'], 'foo')", expectTrue) + test("primitive_int_list_contains", "contains(primitives.[integer, integer], primitives.integer)", expectTrue) + test("primitive_bools", "primitives.[boolean, boolean]") { + rust("assert_eq!(2, result.len());") + rust("assert!(*result[0]);") + rust("assert!(*result[1]);") + } + test("primitive_strings_contain", "contains(primitives.[string, string], primitives.string)", expectTrue) + } + + private fun TestCase.complexCombinationsOfFeatures() { + fun test( + name: String, + expression: String, + assertions: RustWriter.() -> Unit, + ) = testCase("traverse_complex_combos_$name", expression, testDataFn1, assertions) + + test( + "1", + "(length(lists.structs[?!(integer < `0`) && integer >= `0` || `false`]) == `5`) == contains(lists.integers, length(maps.structs.*.strings))", + itCompiles, + ) + } + + private fun testModel() = + """ + ${'$'}version: "2" + namespace test + + @aws.protocols#awsJson1_0 + service TestService { + operations: [TestOperation], + } + + operation TestOperation { + input: GetEntityRequest, + output: GetEntityResponse, + errors: [], + } + + structure GetEntityRequest { + @required + name: String + } + + structure GetEntityResponse { + primitives: EntityPrimitives, + lists: EntityLists, + maps: EntityMaps, + } + + structure EntityPrimitives { + boolean: Boolean, + string: String, + byte: Byte, + short: Short, + integer: Integer, + long: Long, + float: Float, + double: Double, + enum: Enum, + intEnum: IntEnum, + + @required requiredBoolean: Boolean, + @required requiredString: String, + } + + structure EntityLists { + booleans: BooleanList, + strings: StringList, + shorts: ShortList + integers: IntegerList, + longs: LongList + floats: FloatList, + doubles: DoubleList, + enums: EnumList, + intEnums: IntEnumList, + structs: StructList, + } + + structure EntityMaps { + booleans: BooleanMap, + strings: StringMap, + integers: IntegerMap, + enums: EnumMap, + intEnums: IntEnumMap, + structs: StructMap, + } + + enum Enum { + ONE = "one", + TWO = "two", + } + + intEnum IntEnum { + ONE = 1, + TWO = 2, + } + + structure Struct { + primitives: EntityPrimitives, + strings: StringList, + integer: Integer, + string: String, + enums: EnumList, + subStructs: SubStructList, + } + + structure SubStruct { + subStructPrimitives: EntityPrimitives, + } + + list BooleanList { member: Boolean } + list StringList { member: String } + list ShortList { member: Short } + list IntegerList { member: Integer } + list LongList { member: Long } + list FloatList { member: Float } + list DoubleList { member: Double } + list EnumList { member: Enum } + list IntEnumList { member: IntEnum } + list StructList { member: Struct } + list SubStructList { member: SubStruct } + map BooleanMap { key: String, value: Boolean } + map StringMap { key: String, value: String } + map IntegerMap { key: String, value: Integer } + map EnumMap { key: String, value: Enum } + map IntEnumMap { key: String, value: IntEnum } + map StructMap { key: String, value: Struct } + """.asSmithyModel() +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index 8ac936d67930e4544c3ea2561bcec6b990282be7..82532ff5ed5de1511f8cd194bb6ae46ddcbe4620 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -257,7 +257,7 @@ fun RustType.render(fullyQualified: Boolean = true): String { if (this.lifetime == "&") { "&${this.member.render(fullyQualified)}" } else { - "&${this.lifetime?.let { "'$it" } ?: ""} ${this.member.render(fullyQualified)}" + "&${this.lifetime?.let { "'$it " } ?: ""}${this.member.render(fullyQualified)}" } } is RustType.Application -> { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index 7034ed1f12acc985aedef5eb0cbb261f930eaeac..c0922525b02e672c95ae400e68f72fb1d0b6057b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -524,13 +524,22 @@ fun RustWriter.rawTemplate( */ fun docLink(docLink: String): String = docLink.replace("::r##", "::").replace("::r#", "::") +class SafeNamer { + private var n = 0 + + fun safeName(prefix: String = "var"): String { + n += 1 + return "${prefix}_$n" + } +} + class RustWriter private constructor( private val filename: String, val namespace: String, private val commentCharacter: String = "//", private val printWarning: Boolean = true, /** Insert comments indicating where code was generated */ - private val debugMode: Boolean = false, + val debugMode: Boolean = false, /** When true, automatically change all dependencies to be in the test scope */ val devDependenciesOnly: Boolean = false, ) : @@ -619,7 +628,7 @@ class RustWriter private constructor( private val preamble = mutableListOf() private val formatter = RustSymbolFormatter() - private var n = 0 + private val safeNamer = SafeNamer() init { expressionStart = '#' @@ -644,10 +653,7 @@ class RustWriter private constructor( null } - fun safeName(prefix: String = "var"): String { - n += 1 - return "${prefix}_$n" - } + fun safeName(prefix: String = "var"): String = safeNamer.safeName(prefix) fun first(preWriter: Writable) { preamble.add(preWriter) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt index ad4b07c2308ba854427a85755a0a448a1b9361aa..7b9ad233667e61e428a84c4f1becca939a472768 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt @@ -37,12 +37,11 @@ fun Writable.some(): Writable { fun Writable.isNotEmpty(): Boolean = !this.isEmpty() -operator fun Writable.plus(other: Writable): Writable { - val first = this - return writable { - rustTemplate("#{First:W}#{Second:W}", "First" to first, "Second" to other) +operator fun Writable.plus(other: Writable): Writable = + writable { + this@plus(this) + other(this) } -} /** * Helper allowing a `Iterable` to be joined together using a `String` separator.