Unverified Commit 55faed73 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Add multiple binding support to Jmespath traversal generator (#3569)

In #3526, I forgot about the existence of input/output matchers in
Smithy waiters, so it only supported binding against a single global
output shape. This PR revises the generator so that it supports multiple
bindings so that it can generate an input/output matcher path traversal.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 64b9b91f
Loading
Loading
Loading
Loading
+90 −63
Original line number Diff line number Diff line
@@ -159,6 +159,33 @@ data class GeneratedExpression(
    }
}

/**
 * Identifier binding for JmesPath expressions.
 */
sealed class TraversalBinding {
    /** The name of this binding in the generated Rust code */
    abstract val rustName: String

    /** The Smithy shape behind this binding */
    abstract val shape: Shape

    /** Binds the given shape to the global namespace such that all its members are globally available */
    data class Global(
        override val rustName: String,
        override val shape: Shape,
    ) : TraversalBinding()

    /** Binds a shape to a name */
    data class Named(
        /** What this binding is referred to in JmesPath expressions */
        val jmespathName: String,
        override val rustName: String,
        override val shape: Shape,
    ) : TraversalBinding()
}

typealias TraversalBindings = List<TraversalBinding>

/**
 * Indicates a feature that's part of the JmesPath spec, but that we explicitly decided
 * not to support in smithy-rs due to the complexity of code generating it for Rust.
@@ -197,24 +224,23 @@ class RustJmespathShapeTraversalGenerator(

    fun generate(
        expr: JmespathExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        try {
            val result =
                when (expr) {
                    is ComparatorExpression -> generateComparator(expr, inputName, shape)
                    is FunctionExpression -> generateFunction(expr, inputName, shape)
                    is FieldExpression -> generateField(expr, inputName, shape)
                    is ComparatorExpression -> generateComparator(expr, bindings)
                    is FunctionExpression -> generateFunction(expr, bindings)
                    is FieldExpression -> generateField(expr, bindings)
                    is LiteralExpression -> generateLiteral(expr)
                    is MultiSelectListExpression -> generateMultiSelectList(expr, inputName, shape)
                    is AndExpression -> generateAnd(expr, inputName, shape)
                    is OrExpression -> generateOr(expr, inputName, shape)
                    is NotExpression -> generateNot(expr, inputName, shape)
                    is ObjectProjectionExpression -> generateObjectProjection(expr, inputName, shape)
                    is FilterProjectionExpression -> generateFilterProjection(expr, inputName, shape)
                    is ProjectionExpression -> generateProjection(expr, inputName, shape)
                    is Subexpression -> generateSubexpression(expr, inputName, shape)
                    is MultiSelectListExpression -> generateMultiSelectList(expr, bindings)
                    is AndExpression -> generateAnd(expr, bindings)
                    is OrExpression -> generateOr(expr, bindings)
                    is NotExpression -> generateNot(expr, bindings)
                    is ObjectProjectionExpression -> generateObjectProjection(expr, bindings)
                    is FilterProjectionExpression -> generateFilterProjection(expr, bindings)
                    is ProjectionExpression -> generateProjection(expr, bindings)
                    is Subexpression -> generateSubexpression(expr, bindings)
                    is CurrentExpression -> throw JmesPathTraversalCodegenBugException("current expression must be handled in each expression type that can have one")
                    is ExpressionTypeExpression -> throw UnsupportedJmesPathException("Expression type expressions are not supported by smithy-rs")
                    is IndexExpression -> throw UnsupportedJmesPathException("Index expressions are not supported by smithy-rs")
@@ -242,11 +268,10 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateComparator(
        expr: ComparatorExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val left = generate(expr.left, inputName, shape)
        val right = generate(expr.right, inputName, shape)
        val left = generate(expr.left, bindings)
        val right = generate(expr.right, bindings)
        return generateCompare(left, right, expr.comparator.toString())
    }

@@ -293,8 +318,7 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateFunction(
        expr: FunctionExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val ident = safeNamer.safeName("_ret")
        return when (expr.name) {
@@ -302,7 +326,7 @@ class RustJmespathShapeTraversalGenerator(
                if (expr.arguments.size != 1) {
                    throw InvalidJmesPathTraversalException("Length function takes exactly one argument")
                }
                val arg = generate(expr.arguments[0], inputName, shape)
                val arg = generate(expr.arguments[0], bindings)
                if (!arg.outputType.isCollection() && !arg.outputType.isString()) {
                    throw InvalidJmesPathTraversalException("Argument to `length` function must be a collection or string type")
                }
@@ -321,14 +345,14 @@ class RustJmespathShapeTraversalGenerator(
                if (expr.arguments.size != 2) {
                    throw InvalidJmesPathTraversalException("Contains function takes exactly two arguments")
                }
                val left = generate(expr.arguments[0], inputName, shape)
                val left = generate(expr.arguments[0], bindings)
                if (!left.outputType.isCollection() && !left.outputType.isString()) {
                    throw InvalidJmesPathTraversalException("First argument to `contains` function must be a collection or string type")
                }
                if (expr.arguments[1].isLiteralNull()) {
                    throw UnsupportedJmesPathException("Checking for null with `contains` is not supported in smithy-rs")
                }
                val right = generate(expr.arguments[1], inputName, shape)
                val right = generate(expr.arguments[1], bindings)
                if (!right.outputType.isNumber() && !right.outputType.isString() && right.outputShape?.isEnumShape != true) {
                    throw UnsupportedJmesPathException("Checking for anything other than numbers, strings, or enums in the `contains` function is not supported in smithy-rs")
                }
@@ -391,13 +415,23 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateField(
        expr: FieldExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        if (shape is StructureShape) {
        val globalBinding = bindings.find { it is TraversalBinding.Global }
        val namedBinding = bindings.find { it is TraversalBinding.Named && it.jmespathName == expr.name }
        if (namedBinding != null && namedBinding.shape is StructureShape) {
            // If there's a named binding that matches, then immediately return it
            return GeneratedExpression(
                identifier = namedBinding.rustName,
                outputShape = namedBinding.shape,
                outputType = symbolProvider.toSymbol(namedBinding.shape).rustType().asRef(),
                output = writable { },
            )
        } else if (globalBinding != null && globalBinding.shape is StructureShape) {
            // Otherwise, look in the global binding (if available)
            val member =
                shape.getMember(expr.name).orNull()
                    ?: throw InvalidJmesPathTraversalException("Member `${expr.name}` doesn't exist on ${shape.id}")
                globalBinding.shape.getMember(expr.name).orNull()
                    ?: throw InvalidJmesPathTraversalException("Member `${expr.name}` doesn't exist on ${globalBinding.shape.id}")
            val memberSym = symbolProvider.toSymbol(member)

            val target = model.expectShape(member.target)
@@ -412,15 +446,17 @@ class RustJmespathShapeTraversalGenerator(
                    writable {
                        rust(
                            if (memberSym.isOptional()) {
                                "let $ident = $inputName.${memberSym.name}.as_ref()?;"
                                "let $ident = ${globalBinding.rustName}.${memberSym.name}.as_ref()?;"
                            } else {
                                "let $ident = &$inputName.${memberSym.name};"
                                "let $ident = &${globalBinding.rustName}.${memberSym.name};"
                            },
                        )
                    },
            )
        } else {
        } else if (namedBinding != null || globalBinding != null) {
            throw InvalidJmesPathTraversalException("Cannot look up fields in non-struct shapes")
        } else {
            throw JmesPathTraversalCodegenBugException("Missing jmespath traversal binding for ${expr.name}; available bindings: $bindings")
        }
    }

@@ -459,12 +495,11 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateMultiSelectList(
        expr: MultiSelectListExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val expressions =
            expr.expressions.map { subexpr ->
                generate(subexpr, inputName, shape)
                generate(subexpr, bindings)
            }
        // If we wanted to support mixed-types, we would need to use tuples, add tuple support to RustType,
        // and update supported functions such as `contains` to operate on tuples.
@@ -489,24 +524,21 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateAnd(
        expr: AndExpression,
        inputName: String,
        shape: Shape,
    ): GeneratedExpression = generateBooleanOp(expr, "&&", inputName, shape)
        bindings: TraversalBindings,
    ): GeneratedExpression = generateBooleanOp(expr, "&&", bindings)

    private fun generateOr(
        expr: OrExpression,
        inputName: String,
        shape: Shape,
    ): GeneratedExpression = generateBooleanOp(expr, "||", inputName, shape)
        bindings: TraversalBindings,
    ): GeneratedExpression = generateBooleanOp(expr, "||", bindings)

    private fun generateBooleanOp(
        expr: BinaryExpression,
        op: String,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val left = generate(expr.left, inputName, shape)
        val right = generate(expr.right, inputName, shape)
        val left = generate(expr.left, bindings)
        val right = generate(expr.right, bindings)
        if (!left.outputType.isBool() || !right.outputType.isBool()) {
            throw UnsupportedJmesPathException("Applying the `$op` operation doesn't support non-boolean types in smithy-rs")
        }
@@ -527,10 +559,9 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateNot(
        expr: NotExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val inner = generate(expr.expression, inputName, shape)
        val inner = generate(expr.expression, bindings)
        if (!inner.outputType.isBool()) {
            throw UnsupportedJmesPathException("Negation of a non-boolean type is not supported by smithy-rs")
        }
@@ -550,8 +581,7 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateProjection(
        expr: ProjectionExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val maybeFlatten = expr.left
        if (maybeFlatten is SliceExpression) {
@@ -560,7 +590,7 @@ class RustJmespathShapeTraversalGenerator(
        if (maybeFlatten !is FlattenExpression) {
            throw UnsupportedJmesPathException("Only projection expressions with flattens are supported by smithy-rs")
        }
        val left = generate(maybeFlatten.expression, inputName, shape)
        val left = generate(maybeFlatten.expression, bindings)
        val leftTarget =
            when (val outputShape = left.outputShape) {
                is ListShape -> model.expectShape(outputShape.member.target)
@@ -573,7 +603,7 @@ class RustJmespathShapeTraversalGenerator(
            return left
        }

        val right = generate(expr.right, "v", leftTarget)
        val right = generate(expr.right, listOf(TraversalBinding.Global("v", leftTarget)))
        val projectionType = RustType.Vec(right.outputType.asRef())

        return safeNamer.safeName("_prj").let { ident ->
@@ -604,10 +634,9 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateFilterProjection(
        expr: FilterProjectionExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val left = generate(expr.left, inputName, shape)
        val left = generate(expr.left, bindings)
        if (!left.outputType.isList() && !left.outputType.isSet()) {
            throw UnsupportedJmesPathException("Filter projections can only be done on lists or sets in smithy-rs")
        }
@@ -622,10 +651,10 @@ class RustJmespathShapeTraversalGenerator(
                    output = writable {},
                )
            } else {
                generate(expr.right, "_v", leftTarget)
                generate(expr.right, listOf(TraversalBinding.Global("_v", leftTarget)))
            }

        val comparison = generate(expr.comparison, "_v", leftTarget)
        val comparison = generate(expr.comparison, listOf(TraversalBinding.Global("_v", leftTarget)))
        if (!comparison.outputType.isBool()) {
            throw InvalidJmesPathTraversalException("The filter expression comparison must result in a boolean")
        }
@@ -668,13 +697,12 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateObjectProjection(
        expr: ObjectProjectionExpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        if (expr.left is CurrentExpression) {
            throw UnsupportedJmesPathException("Object projection cannot be done on computed maps in smithy-rs")
        }
        val left = generate(expr.left, inputName, shape)
        val left = generate(expr.left, bindings)
        if (!left.outputType.isMap()) {
            throw UnsupportedJmesPathException("Object projection is only supported on map types in smithy-rs")
        }
@@ -692,7 +720,7 @@ class RustJmespathShapeTraversalGenerator(
                    output = writable {},
                )
            } else {
                generate(expr.right, "_v", leftTarget)
                generate(expr.right, listOf(TraversalBinding.Global("_v", leftTarget)))
            }

        val ident = safeNamer.safeName("_oprj")
@@ -726,11 +754,10 @@ class RustJmespathShapeTraversalGenerator(

    private fun generateSubexpression(
        expr: Subexpression,
        inputName: String,
        shape: Shape,
        bindings: TraversalBindings,
    ): GeneratedExpression {
        val left = generate(expr.left, inputName, shape)
        val right = generate(expr.right, left.identifier, left.outputShape!!)
        val left = generate(expr.left, bindings)
        val right = generate(expr.right, listOf(TraversalBinding.Global(left.identifier, left.outputShape!!)))
        return GeneratedExpression(
            identifier = right.identifier,
            outputShape = right.outputShape,
+87 −24

File changed.

Preview size limit exceeded, changes collapsed.

+11 −0
Original line number Diff line number Diff line
@@ -375,6 +375,17 @@ fun RustType.isEq(): Boolean =
        else -> false
    }

/** Recursively replaces lifetimes with the new lifetime */
fun RustType.replaceLifetimes(newLifetime: String?): RustType =
    when (this) {
        is RustType.Option -> copy(member = member.replaceLifetimes(newLifetime))
        is RustType.Vec -> copy(member = member.replaceLifetimes(newLifetime))
        is RustType.HashSet -> copy(member = member.replaceLifetimes(newLifetime))
        is RustType.HashMap -> copy(key = key.replaceLifetimes(newLifetime), member = member.replaceLifetimes(newLifetime))
        is RustType.Reference -> copy(lifetime = newLifetime)
        else -> this
    }

enum class Visibility {
    PRIVATE,
    PUBCRATE,