Unverified Commit e1986785 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement the waiter matcher generator (#3571)



This PR implements Smithy waiter matcher union codegen.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarysaito1001 <awsaito@amazon.com>
parent 1ae508d0
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -102,6 +102,8 @@ object ClientRustModule {
        /** crate::types::error */
        val Error = RustModule.public("error", parent = self)
    }

    val waiters = RustModule.pubCrate("waiters")
}

class ClientModuleDocProvider(
+192 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.replaceLifetimes
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.waiters.Matcher
import software.amazon.smithy.waiters.Matcher.ErrorTypeMember
import software.amazon.smithy.waiters.Matcher.InputOutputMember
import software.amazon.smithy.waiters.Matcher.OutputMember
import software.amazon.smithy.waiters.Matcher.SuccessMember
import software.amazon.smithy.waiters.PathComparator
import java.security.MessageDigest

private typealias Scope = Array<Pair<String, Any>>

/**
 * Generates the Rust code for the Smithy waiter "matcher union".
 * See https://smithy.io/2.0/additional-specs/waiters.html#matcher-union
 */
class RustWaiterMatcherGenerator(
    private val codegenContext: ClientCodegenContext,
    private val operationName: String,
    private val inputShape: Shape,
    private val outputShape: Shape,
) {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val module = RustModule.pubCrate("matchers", ClientRustModule.waiters)
    private val inputSymbol = codegenContext.symbolProvider.toSymbol(inputShape)
    private val outputSymbol = codegenContext.symbolProvider.toSymbol(outputShape)

    fun generate(
        errorSymbol: Symbol,
        matcher: Matcher<*>,
    ): RuntimeType {
        val fnName = fnName(operationName, matcher)
        val scope =
            arrayOf(
                *preludeScope,
                "Input" to inputSymbol,
                "Output" to outputSymbol,
                "Error" to errorSymbol,
                "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(runtimeConfig),
            )
        return RuntimeType.forInlineFun(fnName, module) {
            docs("Matcher union: " + Node.printJson(matcher.toNode()))
            rustBlockTemplate("pub(crate) fn $fnName(_input: &#{Input}, _result: &#{Result}<#{Output}, #{Error}>) -> bool", *scope) {
                when (matcher) {
                    is OutputMember -> generateOutputMember(outputShape, matcher, scope)
                    is InputOutputMember -> generateInputOutputMember(matcher, scope)
                    is SuccessMember -> generateSuccessMember(matcher)
                    is ErrorTypeMember -> generateErrorTypeMember(matcher, scope)
                    else -> throw CodegenException("Unknown waiter matcher type: $matcher")
                }
            }
        }
    }

    private fun RustWriter.generateOutputMember(
        outputShape: Shape,
        matcher: OutputMember,
        scope: Scope,
    ) {
        val pathExpression = JmespathExpression.parse(matcher.value.path)
        val pathTraversal =
            RustJmespathShapeTraversalGenerator(codegenContext).generate(
                pathExpression,
                listOf(TraversalBinding.Global("_output", outputShape)),
            )

        generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope)
    }

    private fun RustWriter.generateInputOutputMember(
        matcher: InputOutputMember,
        scope: Scope,
    ) {
        val pathExpression = JmespathExpression.parse(matcher.value.path)
        val pathTraversal =
            RustJmespathShapeTraversalGenerator(codegenContext).generate(
                pathExpression,
                listOf(
                    TraversalBinding.Named("input", "_input", inputShape),
                    TraversalBinding.Named("output", "_output", outputShape),
                ),
            )

        generatePathTraversalMatcher(pathTraversal, matcher.value.expected, matcher.value.comparator, scope)
    }

    private fun RustWriter.generatePathTraversalMatcher(
        pathTraversal: GeneratedExpression,
        expected: String,
        comparatorKind: PathComparator,
        scope: Scope,
    ) {
        val comparator =
            writable {
                rust(
                    when (comparatorKind) {
                        PathComparator.ALL_STRING_EQUALS -> "value.iter().all(|s| s == ${expected.dq()})"
                        PathComparator.ANY_STRING_EQUALS -> "value.iter().any(|s| s == ${expected.dq()})"
                        PathComparator.STRING_EQUALS -> "value == ${expected.dq()}"
                        PathComparator.BOOLEAN_EQUALS ->
                            when (pathTraversal.outputType is RustType.Reference) {
                                true -> "*value == $expected"
                                else -> "value == $expected"
                            }
                        else -> throw CodegenException("Unknown path matcher comparator: $comparatorKind")
                    },
                )
            }

        rustTemplate(
            """
            fn path_traversal<'a>(_input: &'a #{Input}, _output: &'a #{Output}) -> #{Option}<#{TraversalOutput}> {
                #{traversal}
                #{Some}(${pathTraversal.identifier})
            }
            _result.as_ref()
                .ok()
                .and_then(|output| path_traversal(_input, output))
                .map(|value| #{comparator})
                .unwrap_or_default()
            """,
            *scope,
            "traversal" to pathTraversal.output,
            "TraversalOutput" to pathTraversal.outputType.replaceLifetimes("a"),
            "comparator" to comparator,
        )
    }

    private fun RustWriter.generateSuccessMember(matcher: SuccessMember) {
        rust(
            if (matcher.value) {
                "_result.is_ok()"
            } else {
                "_result.is_err()"
            },
        )
    }

    private fun RustWriter.generateErrorTypeMember(
        matcher: ErrorTypeMember,
        scope: Scope,
    ) {
        rustTemplate(
            """
            if let #{Err}(err) = _result {
                if let #{Some}(code) = #{ProvideErrorMetadata}::code(err) {
                    return code == ${matcher.value.dq()};
                }
            }
            false
            """,
            *scope,
        )
    }

    private fun fnName(
        operationName: String,
        matcher: Matcher<*>,
    ): String {
        // Smithy models don't give us anything useful to name these functions with, so just
        // SHA-256 hash the matcher JSON and truncate it to a reasonable length. This will have
        // a nice side-effect of de-duplicating identical matchers within a given operation.
        val jsonValue = Node.printJson(matcher.toNode())
        val bytes = MessageDigest.getInstance("SHA-256").digest(jsonValue.toByteArray())
        val hex = bytes.map { byte -> String.format("%02x", byte) }.joinToString("")
        return "match_${operationName.toSnakeCase()}_${hex.substring(0..16)}"
    }
}
+349 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.rust.codegen.client.smithy.generators.waiters

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.waiters.Matcher.SuccessMember

private typealias Scope = Array<Pair<String, Any>>

class RustWaiterMatcherGeneratorTest {
    class TestCase(
        codegenContext: ClientCodegenContext,
        private val rustCrate: RustCrate,
        matcherJson: String,
    ) {
        val operationShape = codegenContext.model.lookup<OperationShape>("test#TestOperation")
        val inputShape = operationShape.inputShape(codegenContext.model)
        val outputShape = operationShape.outputShape(codegenContext.model)
        val errorShape = codegenContext.model.lookup<StructureShape>("test#SomeError")
        val inputSymbol = codegenContext.symbolProvider.toSymbol(inputShape)
        val outputSymbol = codegenContext.symbolProvider.toSymbol(outputShape)
        val errorSymbol = codegenContext.symbolProvider.toSymbol(errorShape)

        val matcher = SuccessMember.fromNode(Node.parse(matcherJson))
        val matcherFn =
            RustWaiterMatcherGenerator(codegenContext, "TestOperation", inputShape, outputShape)
                .generate(errorSymbol, matcher)

        val scope =
            arrayOf(
                *preludeScope,
                "Input" to inputSymbol,
                "Output" to outputSymbol,
                "Error" to errorSymbol,
                "ErrorMetadata" to RuntimeType.errorMetadata(codegenContext.runtimeConfig),
                "matcher_fn" to matcherFn,
            )

        fun renderTest(
            name: String,
            writeTest: TestCase.() -> Writable,
        ) {
            rustCrate.lib {
                rustTemplate(
                    """
                    /// Make the unit test public and document it so that compiler
                    /// doesn't complain about dead code.
                    pub fn ${name}_test_case() {
                        #{test}
                    }
                    ##[cfg(test)]
                    ##[test]
                    fn $name() {
                        ${name}_test_case();
                    }
                    """,
                    *scope,
                    "test" to writeTest(),
                )
            }
        }
    }

    @Test
    fun tests() {
        clientIntegrationTest(testModel()) { codegenContext, rustCrate ->
            successMatcher(codegenContext, rustCrate)
            errorMatcher(codegenContext, rustCrate)
            outputPathMatcher(codegenContext, rustCrate)
            inputOutputPathMatcher(codegenContext, rustCrate)
        }
    }

    private fun testCase(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
        name: String,
        matcherJson: String,
        writeFn: RustWriter.(Scope) -> Unit,
    ) {
        TestCase(codegenContext, rustCrate, matcherJson).renderTest(name) {
            writable {
                writeFn(scope)
            }
        }
    }

    private fun successMatcher(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
    ) = testCase(
        codegenContext,
        rustCrate,
        name = "success_matcher",
        matcherJson = """{"success":true}""",
    ) { scope ->
        rustTemplate(
            """
            let input = #{Input}::builder().foo("foo").build().unwrap();
            let result = #{Ok}(#{Output}::builder().some_string("bar").build());
            assert!(#{matcher_fn}(&input, &result));

            let result = #{Err}(#{Error}::builder().message("asdf").build());
            assert!(!#{matcher_fn}(&input, &result));
            """,
            *scope,
        )
    }

    private fun errorMatcher(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
    ) = testCase(
        codegenContext,
        rustCrate,
        name = "error_matcher",
        matcherJson = """{"errorType":"SomeError"}""",
    ) { scope ->
        rustTemplate(
            """
            let input = #{Input}::builder().foo("foo").build().unwrap();
            let result = #{Ok}(#{Output}::builder().some_string("bar").build());
            assert!(!#{matcher_fn}(&input, &result));

            let result = #{Err}(
                #{Error}::builder()
                    .message("asdf")
                    .meta(#{ErrorMetadata}::builder().code("SomeOtherError").build())
                    .build()
            );
            assert!(!#{matcher_fn}(&input, &result));

            let result = #{Err}(
                #{Error}::builder()
                    .message("asdf")
                    .meta(#{ErrorMetadata}::builder().code("SomeError").build())
                    .build()
            );
            assert!(#{matcher_fn}(&input, &result));
            """,
            *scope,
        )
    }

    private fun outputPathMatcher(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
    ) {
        fun test(
            name: String,
            matcherJson: String,
            writeFn: RustWriter.(Scope) -> Unit,
        ) = testCase(codegenContext, rustCrate, name, matcherJson, writeFn)

        fun matcherJson(
            path: String,
            expected: String,
            comparator: String,
        ) = """{"output":{"path":${path.dq()}, "expected":${expected.dq()}, "comparator": ${comparator.dq()}}}"""

        test(
            "output_path_matcher_string_equals",
            matcherJson(
                path = "someString",
                expected = "expected-value",
                comparator = "stringEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_string("bar").build());
                assert!(!#{matcher_fn}(&input, &result));

                let result = #{Ok}(#{Output}::builder().some_string("expected-value").build());
                assert!(#{matcher_fn}(&input, &result));
                """,
                *scope,
            )
        }

        test(
            "output_path_matcher_bool_equals",
            matcherJson(
                path = "someBool",
                expected = "true",
                comparator = "booleanEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_bool(false).build());
                assert!(!#{matcher_fn}(&input, &result));

                let result = #{Ok}(#{Output}::builder().some_bool(true).build());
                assert!(#{matcher_fn}(&input, &result));
                """,
                *scope,
            )
        }

        test(
            "output_path_matcher_all_string_equals",
            matcherJson(
                path = "someList",
                expected = "foo",
                comparator = "allStringEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder()
                    .some_list("foo")
                    .some_list("bar")
                    .build());
                assert!(!#{matcher_fn}(&input, &result));

                let result = #{Ok}(#{Output}::builder()
                    .some_list("foo")
                    .some_list("foo")
                    .build());
                assert!(#{matcher_fn}(&input, &result));
                """,
                *scope,
            )
        }

        test(
            "output_path_matcher_any_string_equals",
            matcherJson(
                path = "someList",
                expected = "foo",
                comparator = "anyStringEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder()
                    .some_list("bar")
                    .build());
                assert!(!#{matcher_fn}(&input, &result));

                let result = #{Ok}(#{Output}::builder()
                    .some_list("bar")
                    .some_list("foo")
                    .build());
                assert!(#{matcher_fn}(&input, &result));
                """,
                *scope,
            )
        }
    }

    private fun inputOutputPathMatcher(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
    ) {
        fun test(
            name: String,
            matcherJson: String,
            writeFn: RustWriter.(Scope) -> Unit,
        ) = testCase(codegenContext, rustCrate, name, matcherJson, writeFn)

        fun matcherJson(
            path: String,
            expected: String,
            comparator: String,
        ) = """{"inputOutput":{"path":${path.dq()}, "expected":${expected.dq()}, "comparator": ${comparator.dq()}}}"""

        test(
            "input_output_path_matcher_boolean_equals",
            matcherJson(
                path = "input.foo == 'foo' && output.someString == 'bar'",
                expected = "true",
                comparator = "booleanEquals",
            ),
        ) { scope ->
            rustTemplate(
                """
                let input = #{Input}::builder().foo("foo").build().unwrap();
                let result = #{Ok}(#{Output}::builder().some_string("bar").build());
                assert!(#{matcher_fn}(&input, &result));

                let input = #{Input}::builder().foo("asdf").build().unwrap();
                assert!(!#{matcher_fn}(&input, &result));
                """,
                *scope,
            )
        }
    }

    private fun testModel() =
        """
        ${'$'}version: "2"
        namespace test

        @aws.protocols#awsJson1_0
        service TestService {
            operations: [TestOperation],
        }

        operation TestOperation {
            input: GetEntityRequest,
            output: GetEntityResponse,
            errors: [SomeError],
        }

        @error("server")
        structure SomeError {
            message: String,
        }

        structure GetEntityRequest {
            foo: String,
        }

        structure GetEntityResponse {
            someString: String,
            someBool: Boolean,
            someList: SomeList,
        }

        list SomeList {
            member: String
        }
        """.asSmithyModel()
}