Unverified Commit e4a58c36 authored by Landon James's avatar Landon James Committed by GitHub
Browse files

Add support for `operationContextParams` Endpoints trait (#3755)

## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
We have to support the new [`operationContextParams`
trait](https://smithy.io/2.0/additional-specs/rules-engine/parameters.html#smithy-rules-operationcontextparams-trait)
for endpoint resolution. This trait specifies JMESPath expressions for
selecting parameter data from the operation's input type.

## Description
<!--- Describe your changes in detail -->
* Add codegen support for the [JMESPath
`keys`](https://jmespath.org/specification.html#keys) function (required
by the trait
[spec](https://smithy.io/2.0/additional-specs/rules-engine/parameters.html#smithy-rules-operationcontextparams-trait))
* Add codegen support for the trait itself. This is achieved by
generating `get_param_name` functions for each param specified in
`operationContextParams`. These functions pull the data out of the input
object and it is added to the endpoint params in the
`${operationName}EndpointParamsInterceptor`

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Updated the existing test suite for JMESPath codegen to test the `keys`
function. Updated the existing EndpointsDecoratorTest with an
`operationContextParams` trait specifying one param of each supported
type to test the codegen.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 2313eb9e
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -10,3 +10,15 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[smithy-rs]]
message = "Support `stringArray` type in endpoints params"
references = ["smithy-rs#3742"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"

[[smithy-rs]]
message = "Add support for `operationContextParams` Endpoints trait"
references = ["smithy-rs#3755"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"
+7 −2
Original line number Diff line number Diff line
@@ -120,6 +120,8 @@ internal class EndpointParamsGenerator(
        fun memberName(parameterName: String) = Identifier.of(parameterName).rustName()

        fun setterName(parameterName: String) = "set_${memberName(parameterName)}"

        fun getterName(parameterName: String) = "get_${memberName(parameterName)}"
    }

    fun paramsStruct(): RuntimeType =
@@ -230,7 +232,9 @@ internal class EndpointParamsGenerator(

    private fun generateEndpointParamsBuilder(rustWriter: RustWriter) {
        rustWriter.docs("Builder for [`Params`]")
        Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(rustWriter)
        Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(
            rustWriter,
        )
        rustWriter.rustBlock("pub struct ParamsBuilder") {
            parameters.toList().forEach { parameter ->
                val name = parameter.memberName()
@@ -253,7 +257,8 @@ internal class EndpointParamsGenerator(
                        rustBlockTemplate("#{Params}", "Params" to paramsStruct()) {
                            parameters.toList().forEach { parameter ->
                                rust("${parameter.memberName()}: self.${parameter.memberName()}")
                                parameter.default.orNull()?.also { default -> rust(".or_else(||Some(${value(default)}))") }
                                parameter.default.orNull()
                                    ?.also { default -> rust(".or_else(||Some(${value(default)}))") }
                                if (parameter.isRequired) {
                                    rustTemplate(
                                        ".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?",
+74 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators

import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
@@ -20,16 +21,23 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.configParamNewtype
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.loadFromConfigBag
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.RustJmespathShapeTraversalGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversalBinding
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversedShape
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asRef
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
@@ -103,10 +111,16 @@ class EndpointParamsInterceptorGenerator(
                    #{Ok}(())
                }
            }

            // The get_* functions below are generated from JMESPath expressions in the
            // operationContextParams trait. They target the operation's input shape.

            #{jmespath_getters}
            """,
            *codegenScope,
            "endpoint_prefix" to endpointPrefix(operationShape),
            "param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
            "jmespath_getters" to jmesPathGetters(operationShape),
        )
    }

@@ -140,6 +154,33 @@ class EndpointParamsInterceptorGenerator(
            rust(".$setterName(#W)", value)
        }

        idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
            val setterName = EndpointParamsGenerator.setterName(name)
            val getterName = EndpointParamsGenerator.getterName(name)
            val pathValue = param.path
            val pathExpression = JmespathExpression.parse(pathValue)
            val pathTraversal =
                RustJmespathShapeTraversalGenerator(codegenContext).generate(
                    pathExpression,
                    listOf(
                        TraversalBinding.Global(
                            "input",
                            TraversedShape.from(model, operationShape.inputShape(model)),
                        ),
                    ),
                )

            when (pathTraversal.outputType) {
                is RustType.Vec -> {
                    rust(".$setterName($getterName(_input))")
                }

                else -> {
                    rust(".$setterName($getterName(_input).cloned())")
                }
            }
        }

        // lastly, allow these to be overridden by members
        memberParams.forEach { (memberShape, param) ->
            val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
@@ -151,6 +192,39 @@ class EndpointParamsInterceptorGenerator(
        }
    }

    private fun jmesPathGetters(operationShape: OperationShape) =
        writable {
            val idx = ContextIndex.of(codegenContext.model)
            val inputShape = operationShape.inputShape(codegenContext.model)
            val input = symbolProvider.toSymbol(inputShape)

            idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
                val getterName = EndpointParamsGenerator.getterName(name)
                val pathValue = param.path
                val pathExpression = JmespathExpression.parse(pathValue)
                val pathTraversal =
                    RustJmespathShapeTraversalGenerator(codegenContext).generate(
                        pathExpression,
                        listOf(
                            TraversalBinding.Global(
                                "input",
                                TraversedShape.from(model, operationShape.inputShape(model)),
                            ),
                        ),
                    )

                rust("// Generated from JMESPath Expression: $pathValue")
                rustBlockTemplate(
                    "fn $getterName(input: #{Input}) -> Option<#{Ret}>",
                    "Input" to input.rustType().asRef(),
                    "Ret" to pathTraversal.outputType,
                ) {
                    pathTraversal.output(this)
                    rust("Some(${pathTraversal.identifier})")
                }
            }
        }

    private fun Node.toWritable(): Writable {
        val node = this
        return writable {
+38 −1
Original line number Diff line number Diff line
@@ -128,6 +128,8 @@ data class GeneratedExpression(

    internal fun isStringOrEnum(): Boolean = isString() || isEnum()

    internal fun isObject(): Boolean = outputShape is TraversedShape.Object

    /** Dereferences this expression if it is a reference. */
    internal fun dereference(namer: SafeNamer): GeneratedExpression =
        if (outputType is RustType.Reference) {
@@ -278,7 +280,7 @@ class JmesPathTraversalCodegenBugException(msg: String?, what: Throwable? = null
 * - Object projections
 * - Multi-select lists (but only when every item in the list is the exact same type)
 * - And/or/not boolean operations
 * - Functions `contains` and `length`. The `keys` function may be supported in the future.
 * - Functions `contains`, `length`, and `keys`.
 */
class RustJmespathShapeTraversalGenerator(
    codegenContext: ClientCodegenContext,
@@ -429,6 +431,41 @@ class RustJmespathShapeTraversalGenerator(
                }
            }

            "keys" -> {
                if (expr.arguments.size != 1) {
                    throw InvalidJmesPathTraversalException("Keys function takes exactly one argument")
                }
                val arg = generate(expr.arguments[0], bindings)
                if (!arg.isObject()) {
                    throw InvalidJmesPathTraversalException("Argument to `keys` function must be an object type")
                }
                GeneratedExpression(
                    identifier = ident,
                    outputType = RustType.Vec(RustType.String),
                    outputShape = TraversedShape.Array(null, TraversedShape.String(null)),
                    output =
                        writable {
                            arg.output(this)
                            val outputShape = arg.outputShape.shape
                            when (outputShape) {
                                is StructureShape -> {
                                    // Can't iterate a struct in Rust so source the keys from smithy
                                    val keys =
                                        outputShape.allMembers.keys.joinToString(",") { "${it.dq()}.to_string()" }
                                    rust("let $ident = vec![$keys];")
                                }

                                is MapShape -> {
                                    rust("let $ident = ${arg.identifier}.keys().map(Clone::clone).collect::<Vec<String>>();")
                                }

                                else ->
                                    throw UnsupportedJmesPathException("The shape type for an input to the keys function must be a struct or a map, got ${outputShape?.type}")
                            }
                        },
                )
            }

            else -> throw UnsupportedJmesPathException("The `${expr.name}` function is not supported by smithy-rs")
        }
    }
+94 −21
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ class EndpointsDecoratorTest {

        use smithy.rules#clientContextParams
        use smithy.rules#staticContextParams
        use smithy.rules#operationContextParams
        use smithy.rules#contextParam
        use aws.protocols#awsJson1_1

@@ -64,7 +65,10 @@ class EndpointsDecoratorTest {
                "BoolBuiltInWithDefault": { "required": true, "type": "boolean", "builtIn": "AWS::FooBar", "default": true },
                "AStringParam": { "required": false, "type": "string" },
                "ABoolParam": { "required": false, "type": "boolean" },
                "AStringArrayParam": { "required": false, "type": "stringArray" }
                "AStringArrayParam": { "required": false, "type": "stringArray" },
                "JmesPathParamString": {"required": false, type: "string"},
                "JmesPathParamBoolean": {"required": false, type: "boolean"},
                "JmesPathParamStringArray": {"required": false, type: "stringArray"},
            }
        })
        @clientContextParams(
@@ -112,6 +116,17 @@ class EndpointsDecoratorTest {
            Region: { value: "us-east-2" },
            AStringArrayParam: {value: ["a", "b", "c"]}
        )
        @operationContextParams(
            JmesPathParamString: {
                path: "nested.field",
            }
            JmesPathParamBoolean: {
                path: "nested.boolField",
            }
            JmesPathParamStringArray: {
                path: "keys(nested.mapField)",
            }
        )
        operation TestOperation {
            input: TestOperationInput
        }
@@ -125,7 +140,14 @@ class EndpointsDecoratorTest {
        }

        structure NestedStructure {
            field: String
            field: String,
            boolField: Boolean,
            mapField: IntegerMap,
        }

        map IntegerMap {
            key: String,
            value: Integer
        }
        """.asSmithyModel(disableValidation = true)

@@ -157,7 +179,7 @@ class EndpointsDecoratorTest {
                            use std::time::Duration;
                            use $moduleName::{
                                config::endpoint::Params, config::interceptors::BeforeTransmitInterceptorContextRef,
                                config::Intercept, config::SharedAsyncSleep, Client, Config,
                                config::Intercept, config::SharedAsyncSleep, types::NestedStructure, Client, Config,
                            };

                            ##[derive(Clone, Debug, Default)]
@@ -178,10 +200,8 @@ class EndpointsDecoratorTest {
                                    let params = cfg
                                        .load::<EndpointResolverParams>()
                                        .expect("params set in config");
                                    let params: &Params = params.get().expect("correct type");
                                    assert_eq!(
                                        params,
                                        &Params::builder()
                                    let preset_params: &Params = params.get().expect("correct type");
                                    let manual_params: &Params = &Params::builder()
                                        .bucket("bucket-name".to_string())
                                        .built_in_with_default("some-default")
                                        .bool_built_in_with_default(true)
@@ -192,11 +212,51 @@ class EndpointsDecoratorTest {
                                            vec!["a", "b", "c"]
                                                .iter()
                                                .map(ToString::to_string)
                                                    .collect::<Vec<_>>()
                                                .collect::<Vec<_>>(),
                                        )
                                        .jmes_path_param_string_array(vec!["key2".to_string(), "key1".to_string()])
                                        .jmes_path_param_string("nested-field")
                                        .build()
                                            .unwrap()
                                        .unwrap();

                                    // The params struct for this test contains a vec sourced from the JMESPath keys function which
                                    // does not guarantee the order. Due to this we cannot compare the preset_params with the
                                    // manual_params directly, instead we must assert equlaity field by field.
                                    assert_eq!(preset_params.bucket(), manual_params.bucket());
                                    assert_eq!(preset_params.region(), manual_params.region());
                                    assert_eq!(
                                        preset_params.a_string_param(),
                                        manual_params.a_string_param()
                                    );
                                    assert_eq!(
                                        preset_params.built_in_with_default(),
                                        manual_params.built_in_with_default()
                                    );
                                    assert_eq!(
                                        preset_params.bool_built_in_with_default(),
                                        manual_params.bool_built_in_with_default()
                                    );
                                    assert_eq!(preset_params.a_bool_param(), manual_params.a_bool_param());
                                    assert_eq!(
                                        preset_params.a_string_array_param(),
                                        manual_params.a_string_array_param()
                                    );
                                    assert_eq!(
                                        preset_params.jmes_path_param_string(),
                                        manual_params.jmes_path_param_string()
                                    );
                                    assert_eq!(
                                        preset_params.jmes_path_param_boolean(),
                                        manual_params.jmes_path_param_boolean()
                                    );
                                    assert!(preset_params
                                        .jmes_path_param_string_array()
                                        .unwrap()
                                        .contains(&"key1".to_string()));
                                    assert!(preset_params
                                        .jmes_path_param_string_array()
                                        .unwrap()
                                        .contains(&"key2".to_string()));

                                    let endpoint = cfg.load::<Endpoint>().expect("endpoint set in config");
                                    assert_eq!(endpoint.url(), "https://www.us-east-2.example.com");
@@ -222,7 +282,20 @@ class EndpointsDecoratorTest {
                                .build();
                            let client = Client::from_conf(config);

                            let _ = dbg!(client.test_operation().bucket("bucket-name").send().await);
                            let _ = dbg!(
                                client
                                .test_operation()
                                .bucket("bucket-name")
                                .nested(
                                    NestedStructure::builder()
                                        .field("nested-field")
                                        .map_field("key1", 1)
                                        .map_field("key2", 2)
                                        .build()
                                )
                                .send()
                                .await
                            );
                            assert!(
                                interceptor.called.load(Ordering::Relaxed),
                                "the interceptor should have been called"
Loading