Unverified Commit 3f80a071 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix several waiter path matcher issues (#3593)

While implementing waiters in another branch, I discovered and fixed a
number of issues with the path matcher codegen logic. These issues were:

- Generated code for nested flatten projections failed to compile due to
the first projection producing a `Vec<Vec<&T>>` instead of `Vec<&T>`.
- Path matchers that don't use input were taking input as an argument
anyway, which results in an unnecessary clone of the input when used by
the generated waiter logic.
- The comparisons generated by `RustWaiterMatcherGenerator` would fail
to compile whenever comparing a string against an enum.

This PR fixes all these issues.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 08cb8a2e
Loading
Loading
Loading
Loading
+65 −50
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
@@ -78,10 +79,10 @@ data class GeneratedExpression(
    val output: Writable,
) {
    /** True if the type is a String, &str, or the shape is an enum shape. */
    fun isStringOrEnum(): Boolean = outputType.isString() || outputShape?.isEnumShape == true
    internal fun isStringOrEnum(): Boolean = outputType.isString() || outputShape?.isEnumShape == true

    /** Dereferences this expression if it is a reference. */
    fun dereference(namer: SafeNamer): GeneratedExpression =
    internal fun dereference(namer: SafeNamer): GeneratedExpression =
        if (outputType is RustType.Reference) {
            namer.safeName("_tmp").let { tmp ->
                copy(
@@ -99,7 +100,7 @@ data class GeneratedExpression(
        }

    /** Converts this expression into a &str. */
    fun convertToStrRef(namer: SafeNamer): GeneratedExpression =
    internal fun convertToStrRef(namer: SafeNamer): GeneratedExpression =
        if (outputType is RustType.Reference && outputType.member is RustType.Reference) {
            dereference(namer).convertToStrRef(namer)
        } else if (!outputType.isString()) {
@@ -131,7 +132,7 @@ data class GeneratedExpression(
        }

    /** Converts a number expression into a specific number type */
    fun convertToNumberPrimitive(
    internal fun convertToNumberPrimitive(
        namer: SafeNamer,
        desiredPrimitive: RustType,
    ): GeneratedExpression {
@@ -272,48 +273,7 @@ class RustJmespathShapeTraversalGenerator(
    ): GeneratedExpression {
        val left = generate(expr.left, bindings)
        val right = generate(expr.right, bindings)
        return generateCompare(left, right, expr.comparator.toString())
    }

    private fun generateCompare(
        left: GeneratedExpression,
        right: GeneratedExpression,
        op: String,
    ): GeneratedExpression =
        if (left.outputType.isDoubleReference()) {
            generateCompare(left.dereference(safeNamer), right, op)
        } else if (right.outputType.isDoubleReference()) {
            generateCompare(left, right.dereference(safeNamer), op)
        } else {
            safeNamer.safeName("_cmp").let { ident ->
                return GeneratedExpression(
                    identifier = ident,
                    outputType = RustType.Bool,
                    output =
                        if (left.isStringOrEnum() && right.isStringOrEnum()) {
                            writable {
                                val leftStr = left.convertToStrRef(safeNamer).also { it.output(this) }
                                val rightStr = right.convertToStrRef(safeNamer).also { it.output(this) }
                                rust("let $ident = ${leftStr.identifier} $op ${rightStr.identifier};")
                            }
                        } else if (left.outputType.isNumber() && right.outputType.isNumber()) {
                            writable {
                                val leftPrim =
                                    left.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) }
                                val rightPrim =
                                    right.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) }
                                rust("let $ident = ${leftPrim.identifier} $op ${rightPrim.identifier};")
                            }
                        } else if (left.outputType.isBool() && right.outputType.isBool()) {
                            left.output + right.output +
                                writable {
                                    rust("let $ident = ${left.identifier} $op ${right.identifier};")
                                }
                        } else {
                            throw UnsupportedJmesPathException("Comparison of ${left.outputType.render()} with ${right.outputType.render()} is not supported by smithy-rs")
                        },
                )
            }
        return generateCompare(safeNamer, left, right, expr.comparator.toString())
    }

    private fun generateFunction(
@@ -382,6 +342,7 @@ class RustJmespathShapeTraversalGenerator(
                                    withBlockTemplate("let $ident = ${left.identifier}.iter().any(|_v| {", "});") {
                                        val compare =
                                            generateCompare(
                                                safeNamer,
                                                GeneratedExpression(
                                                    identifier = "_v",
                                                    outputShape =
@@ -604,7 +565,14 @@ class RustJmespathShapeTraversalGenerator(
        }

        val right = generate(expr.right, listOf(TraversalBinding.Global("v", leftTarget)))
        val projectionType = RustType.Vec(right.outputType.asRef())

        // If the right expression results in a collection type, then the resulting vec will need to get flattened.
        // Otherwise, you'll get `Vec<&Vec<T>>` instead of `Vec<&T>`, which causes later projections to fail to compile.
        val (projectionType, flattenNeeded) =
            when {
                right.outputType.isCollection() -> right.outputType.stripOuter<RustType.Reference>() to true
                else -> RustType.Vec(right.outputType.asRef()) to false
            }

        return safeNamer.safeName("_prj").let { ident ->
            GeneratedExpression(
@@ -614,7 +582,8 @@ class RustJmespathShapeTraversalGenerator(
                output =
                    left.output +
                        writable {
                            rustBlock("let $ident = ${left.identifier}.iter().flat_map(") {
                            rust("let $ident = ${left.identifier}.iter()")
                            withBlock(".flat_map(|v| {", "})") {
                                rustBlockTemplate(
                                    "fn map(v: &#{Left}) -> #{Option}<#{Right}>",
                                    *preludeScope,
@@ -624,9 +593,12 @@ class RustJmespathShapeTraversalGenerator(
                                    right.output(this)
                                    rustTemplate("#{Some}(${right.identifier})", *preludeScope)
                                }
                                rust("map")
                                rust("map(v)")
                            }
                            rustTemplate(").collect::<#{Vec}<_>>();", *preludeScope)
                            if (flattenNeeded) {
                                rust(".flatten()")
                            }
                            rustTemplate(".collect::<#{Vec}<_>>();", *preludeScope)
                        },
            )
        }
@@ -767,6 +739,49 @@ class RustJmespathShapeTraversalGenerator(
    }
}

internal fun generateCompare(
    safeNamer: SafeNamer,
    left: GeneratedExpression,
    right: GeneratedExpression,
    op: String,
): GeneratedExpression =
    if (left.outputType.isDoubleReference()) {
        generateCompare(safeNamer, left.dereference(safeNamer), right, op)
    } else if (right.outputType.isDoubleReference()) {
        generateCompare(safeNamer, left, right.dereference(safeNamer), op)
    } else {
        safeNamer.safeName("_cmp").let { ident ->
            return GeneratedExpression(
                identifier = ident,
                outputType = RustType.Bool,
                output =
                    if (left.isStringOrEnum() && right.isStringOrEnum()) {
                        writable {
                            val leftStr = left.convertToStrRef(safeNamer).also { it.output(this) }
                            val rightStr = right.convertToStrRef(safeNamer).also { it.output(this) }
                            rust("let $ident = ${leftStr.identifier} $op ${rightStr.identifier};")
                        }
                    } else if (left.outputType.isNumber() && right.outputType.isNumber()) {
                        writable {
                            val leftPrim =
                                left.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) }
                            val rightPrim =
                                right.convertToNumberPrimitive(safeNamer, left.outputType).also { it.output(this) }
                            rust("let $ident = ${leftPrim.identifier} $op ${rightPrim.identifier};")
                        }
                    } else if (left.outputType.isBool() && right.outputType.isBool()) {
                        writable {
                            val leftPrim = left.dereference(safeNamer).also { it.output(this) }
                            val rightPrim = right.dereference(safeNamer).also { it.output(this) }
                            rust("let $ident = ${leftPrim.identifier} $op ${rightPrim.identifier};")
                        }
                    } else {
                        throw UnsupportedJmesPathException("Comparison of ${left.outputType.render()} with ${right.outputType.render()} is not supported by smithy-rs")
                    },
            )
        }
    }

private fun RustType.dereference(): RustType =
    if (this is RustType.Reference) {
        this.member.dereference()
+82 −15
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.SafeNamer
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.replaceLifetimes
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@@ -30,10 +31,16 @@ import software.amazon.smithy.waiters.Matcher.InputOutputMember
import software.amazon.smithy.waiters.Matcher.OutputMember
import software.amazon.smithy.waiters.Matcher.SuccessMember
import software.amazon.smithy.waiters.PathComparator
import software.amazon.smithy.waiters.Waiter
import java.security.MessageDigest

private typealias Scope = Array<Pair<String, Any>>

/** True if the waiter requires the operation input in its matcher implementation */
fun Waiter.requiresInput(): Boolean = acceptors.any { it.matcher.requiresInput() }

fun Matcher<*>.requiresInput(): Boolean = this is InputOutputMember

/**
 * Generates the Rust code for the Smithy waiter "matcher union".
 * See https://smithy.io/2.0/additional-specs/waiters.html#matcher-union
@@ -63,8 +70,16 @@ class RustWaiterMatcherGenerator(
                "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(runtimeConfig),
            )
        return RuntimeType.forInlineFun(fnName, module) {
            val inputArg =
                when {
                    matcher.requiresInput() -> "_input: &#{Input}, "
                    else -> ""
                }
            docs("Matcher union: " + Node.printJson(matcher.toNode()))
            rustBlockTemplate("pub(crate) fn $fnName(_input: &#{Input}, _result: &#{Result}<#{Output}, #{Error}>) -> bool", *scope) {
            rustBlockTemplate(
                "pub(crate) fn $fnName(${inputArg}_result: #{Result}<&#{Output}, &#{Error}>) -> bool",
                *scope,
            ) {
                when (matcher) {
                    is OutputMember -> generateOutputMember(outputShape, matcher, scope)
                    is InputOutputMember -> generateInputOutputMember(matcher, scope)
@@ -88,7 +103,13 @@ class RustWaiterMatcherGenerator(
                listOf(TraversalBinding.Global("_output", outputShape)),
            )

        generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope)
        generatePathTraversalMatcher(
            pathTraversal,
            matcher.value.expected,
            matcher.value.comparator,
            scope,
            matcher.requiresInput(),
        )
    }

    private fun RustWriter.generateInputOutputMember(
@@ -105,7 +126,13 @@ class RustWaiterMatcherGenerator(
                ),
            )

        generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope)
        generatePathTraversalMatcher(
            pathTraversal,
            matcher.value.expected,
            matcher.value.comparator,
            scope,
            matcher.requiresInput(),
        )
    }

    private fun RustWriter.generatePathTraversalMatcher(
@@ -113,34 +140,74 @@ class RustWaiterMatcherGenerator(
        expected: String,
        comparatorKind: PathComparator,
        scope: Scope,
        requiresInput: Boolean,
    ) {
        val comparator =
            writable {
                val leftIsIterString = listOf(PathComparator.ALL_STRING_EQUALS, PathComparator.ANY_STRING_EQUALS).contains(comparatorKind)
                val left =
                    GeneratedExpression(
                        identifier = "value",
                        outputType =
                            when {
                                leftIsIterString -> RustType.Reference(null, RustType.String)
                                else -> pathTraversal.outputType
                            },
                        outputShape = pathTraversal.outputShape,
                        output = writable {},
                    )
                val rightIsString = PathComparator.BOOLEAN_EQUALS != comparatorKind
                val right =
                    GeneratedExpression(
                        identifier = "right",
                        outputType =
                            when {
                                rightIsString -> RustType.Reference(null, RustType.Opaque("str"))
                                else -> RustType.Bool
                            },
                        output =
                            writable {
                                rust(
                                    "let right = " +
                                        when {
                                            rightIsString -> expected.dq()
                                            else -> expected
                                        } + ";",
                                )
                            },
                    )
                rustTemplate(
                    when (comparatorKind) {
                        PathComparator.ALL_STRING_EQUALS -> "value.iter().all(|s| s == ${expected.dq()})"
                        PathComparator.ANY_STRING_EQUALS -> "value.iter().any(|s| s == ${expected.dq()})"
                        PathComparator.STRING_EQUALS -> "value == ${expected.dq()}"
                        PathComparator.BOOLEAN_EQUALS ->
                            when (pathTraversal.outputType is RustType.Reference) {
                                true -> "*value == $expected"
                                else -> "value == $expected"
                            }
                        PathComparator.ALL_STRING_EQUALS -> "!value.is_empty() && value.iter().all(|value| { #{comparison} })"
                        PathComparator.ANY_STRING_EQUALS -> "value.iter().any(|value| { #{comparison} })"
                        PathComparator.STRING_EQUALS -> "#{comparison}"
                        PathComparator.BOOLEAN_EQUALS -> "#{comparison}"
                        else -> throw CodegenException("Unknown path matcher comparator: $comparatorKind")
                    },
                    "comparison" to
                        writable {
                            val compare = generateCompare(SafeNamer(), left, right, "==")
                            compare.output(this)
                            rust(compare.identifier)
                        },
                )
            }

        val (inputArgDecl, inputArg) =
            when {
                requiresInput -> "_input: &'a #{Input}, " to "_input, "
                else -> "" to ""
            }
        rustTemplate(
            """
            fn path_traversal<'a>(_input: &'a #{Input}, _output: &'a #{Output}) -> #{Option}<#{TraversalOutput}> {
            fn path_traversal<'a>(${inputArgDecl}_output: &'a #{Output}) -> #{Option}<#{TraversalOutput}> {
                #{traversal}
                #{Some}(${pathTraversal.identifier})
            }
            _result.as_ref()
                .ok()
                .and_then(|output| path_traversal(_input, output))
                .map(|value| #{comparator})
                .and_then(|output| path_traversal(${inputArg}output))
                .map(|value| { #{comparator} })
                .unwrap_or_default()
            """,
            *scope,
+3 −0
Original line number Diff line number Diff line
@@ -323,6 +323,9 @@ class RustJmespathShapeTraversalGeneratorTest {
            rust("assert_eq!(1, result.len());")
            rust("assert_eq!(\"test\", result[0]);")
        }
        test("nested_flattens", "lists.structs[].subStructs[].subStructPrimitives.string") {
            // it should compile
        }

        invalid("primitives.integer[]", "Left side of the flatten")
    }
+51 −21
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
@@ -52,6 +53,7 @@ class RustWaiterMatcherGeneratorTest {
                "Output" to outputSymbol,
                "Error" to errorSymbol,
                "ErrorMetadata" to RuntimeType.errorMetadata(codegenContext.runtimeConfig),
                "SomeEnum" to codegenContext.symbolProvider.toSymbol(codegenContext.model.lookup<EnumShape>("test#SomeEnum")),
                "matcher_fn" to matcherFn,
            )

@@ -115,12 +117,11 @@ class RustWaiterMatcherGeneratorTest {
    ) { scope ->
        rustTemplate(
            """
            let input = #{Input}::builder().foo("foo").build().unwrap();
            let result = #{Ok}(#{Output}::builder().some_string("bar").build());
            assert!(#{matcher_fn}(&input, &result));
            assert!(#{matcher_fn}(result.as_ref()));

            let result = #{Err}(#{Error}::builder().message("asdf").build());
            assert!(!#{matcher_fn}(&input, &result));
            assert!(!#{matcher_fn}(result.as_ref()));
            """,
            *scope,
        )
@@ -137,9 +138,8 @@ class RustWaiterMatcherGeneratorTest {
    ) { scope ->
        rustTemplate(
            """
            let input = #{Input}::builder().foo("foo").build().unwrap();
            let result = #{Ok}(#{Output}::builder().some_string("bar").build());
            assert!(!#{matcher_fn}(&input, &result));
            assert!(!#{matcher_fn}(result.as_ref()));

            let result = #{Err}(
                #{Error}::builder()
@@ -147,7 +147,7 @@ class RustWaiterMatcherGeneratorTest {
                    .meta(#{ErrorMetadata}::builder().code("SomeOtherError").build())
                    .build()
            );
            assert!(!#{matcher_fn}(&input, &result));
            assert!(!#{matcher_fn}(result.as_ref()));

            let result = #{Err}(
                #{Error}::builder()
@@ -155,7 +155,7 @@ class RustWaiterMatcherGeneratorTest {
                    .meta(#{ErrorMetadata}::builder().code("SomeError").build())
                    .build()
            );
            assert!(#{matcher_fn}(&input, &result));
            assert!(#{matcher_fn}(result.as_ref()));
            """,
            *scope,
        )
@@ -187,12 +187,11 @@ class RustWaiterMatcherGeneratorTest {
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_string("bar").build());
                assert!(!#{matcher_fn}(&input, &result));
                assert!(!#{matcher_fn}(result.as_ref()));

                let result = #{Ok}(#{Output}::builder().some_string("expected-value").build());
                assert!(#{matcher_fn}(&input, &result));
                assert!(#{matcher_fn}(result.as_ref()));
                """,
                *scope,
            )
@@ -208,12 +207,11 @@ class RustWaiterMatcherGeneratorTest {
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_bool(false).build());
                assert!(!#{matcher_fn}(&input, &result));
                assert!(!#{matcher_fn}(result.as_ref()));

                let result = #{Ok}(#{Output}::builder().some_bool(true).build());
                assert!(#{matcher_fn}(&input, &result));
                assert!(#{matcher_fn}(result.as_ref()));
                """,
                *scope,
            )
@@ -229,18 +227,17 @@ class RustWaiterMatcherGeneratorTest {
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder()
                    .some_list("foo")
                    .some_list("bar")
                    .build());
                assert!(!#{matcher_fn}(&input, &result));
                assert!(!#{matcher_fn}(result.as_ref()));

                let result = #{Ok}(#{Output}::builder()
                    .some_list("foo")
                    .some_list("foo")
                    .build());
                assert!(#{matcher_fn}(&input, &result));
                assert!(#{matcher_fn}(result.as_ref()));
                """,
                *scope,
            )
@@ -256,17 +253,41 @@ class RustWaiterMatcherGeneratorTest {
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder()
                    .some_list("bar")
                    .build());
                assert!(!#{matcher_fn}(&input, &result));
                assert!(!#{matcher_fn}(result.as_ref()));

                let result = #{Ok}(#{Output}::builder()
                    .some_list("bar")
                    .some_list("foo")
                    .build());
                assert!(#{matcher_fn}(&input, &result));
                assert!(#{matcher_fn}(result.as_ref()));
                """,
                *scope,
            )
        }

        test(
            "output_path_matcher_any_string_equals_enum",
            matcherJson(
                path = "someEnumList",
                expected = "Foo",
                comparator = "anyStringEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let result = #{Ok}(#{Output}::builder()
                    .some_enum_list(#{SomeEnum}::Bar)
                    .build());
                assert!(!#{matcher_fn}(result.as_ref()));

                let result = #{Ok}(#{Output}::builder()
                    .some_enum_list(#{SomeEnum}::Bar)
                    .some_enum_list(#{SomeEnum}::Foo)
                    .build());
                assert!(#{matcher_fn}(result.as_ref()));
                """,
                *scope,
            )
@@ -301,10 +322,10 @@ class RustWaiterMatcherGeneratorTest {
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_string("bar").build());
                assert!(#{matcher_fn}(&input, &result));
                assert!(#{matcher_fn}(&input, result.as_ref()));

                let input = #{Input}::builder().foo("asdf").build().unwrap();
                assert!(!#{matcher_fn}(&input, &result));
                assert!(!#{matcher_fn}(&input, result.as_ref()));
                """,
                *scope,
            )
@@ -340,10 +361,19 @@ class RustWaiterMatcherGeneratorTest {
            someString: String,
            someBool: Boolean,
            someList: SomeList,
            someEnumList: SomeEnumList,
        }

        list SomeList {
            member: String
        }

        enum SomeEnum {
            Foo,
            Bar,
        }
        list SomeEnumList {
            member: SomeEnum,
        }
        """.asSmithyModel()
}