Unverified Commit 9361aa52 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Generate endpoint params in the orchestrator codegen (#2658)

## Motivation and Context
This PR adds the codegen logic to generate endpoint parameters in the
endpoint params interceptor.

Fixes #2644.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 12338ad5
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@
use aws_http::user_agent::AwsUserAgent;
use aws_runtime::invocation_id::InvocationId;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::endpoint::Params;
use aws_sdk_s3::Client;
use aws_smithy_client::dvr;
use aws_smithy_client::dvr::MediaType;
@@ -31,7 +30,6 @@ async fn sra_test() {
        .build();
    let client = Client::from_conf(config);
    let fixup = FixupPlugin {
        client: client.clone(),
        timestamp: UNIX_EPOCH + Duration::from_secs(1624036048),
    };

@@ -52,7 +50,6 @@ async fn sra_test() {
}

struct FixupPlugin {
    client: Client,
    timestamp: SystemTime,
}
impl RuntimePlugin for FixupPlugin {
@@ -60,11 +57,6 @@ impl RuntimePlugin for FixupPlugin {
        &self,
        cfg: &mut ConfigBag,
    ) -> Result<(), aws_smithy_runtime_api::client::runtime_plugin::BoxError> {
        let params_builder = Params::builder()
            .set_region(self.client.conf().region().map(|c| c.as_ref().to_string()))
            .bucket("test-bucket");

        cfg.put(params_builder);
        cfg.set_request_time(RequestTime::new(self.timestamp.clone()));
        cfg.put(AwsUserAgent::for_tests());
        cfg.put(InvocationId::for_tests());
+0 −4
Original line number Diff line number Diff line
@@ -47,10 +47,6 @@ private class EndpointParametersRuntimePluginCustomization(
            section.registerInterceptor(codegenContext.runtimeConfig, this) {
                rust("${operationName}EndpointParamsInterceptor")
            }
            // The finalizer interceptor should be registered last
            section.registerInterceptor(codegenContext.runtimeConfig, this) {
                rust("${operationName}EndpointParamsFinalizerInterceptor")
            }
        }
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -101,6 +101,7 @@ class EndpointsDecorator : ClientCodegenDecorator {
    override val name: String = "Endpoints"
    override val order: Byte = 0

    // TODO(enableNewSmithyRuntime): Remove `operationCustomizations` and `InjectEndpointInMakeOperation`
    override fun operationCustomizations(
        codegenContext: ClientCodegenContext,
        operation: OperationShape,
+78 −57
Original line number Diff line number Diff line
@@ -5,10 +5,17 @@

package software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators

import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.traits.EndpointTrait
import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters
import software.amazon.smithy.rulesengine.traits.ContextIndex
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
@@ -17,13 +24,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.orNull

class EndpointParamsInterceptorGenerator(
    private val codegenContext: ClientCodegenContext,
) {
    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext)
    private val codegenScope = codegenContext.runtimeConfig.let { rc ->
        val endpointTypesGenerator = EndpointTypesGenerator.fromContext(codegenContext)
        val runtimeApi = CargoDependency.smithyRuntimeApi(rc).toType()
@@ -34,29 +45,19 @@ class EndpointParamsInterceptorGenerator(
            "ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"),
            "ContextAttachedError" to interceptors.resolve("error::ContextAttachedError"),
            "EndpointResolverParams" to orchestrator.resolve("EndpointResolverParams"),
            "HttpResponse" to orchestrator.resolve("HttpResponse"),
            "HttpRequest" to orchestrator.resolve("HttpRequest"),
            "HttpResponse" to orchestrator.resolve("HttpResponse"),
            "Interceptor" to interceptors.resolve("Interceptor"),
            "InterceptorContext" to interceptors.resolve("InterceptorContext"),
            "InterceptorError" to interceptors.resolve("error::InterceptorError"),
            "ParamsBuilder" to endpointTypesGenerator.paramsBuilder(),
            "Params" to endpointTypesGenerator.paramsStruct(),
        )
    }

    fun render(writer: RustWriter, operationShape: OperationShape) {
        val operationName = symbolProvider.toSymbol(operationShape).name
        renderInterceptor(
            writer,
            "${operationName}EndpointParamsInterceptor",
            implInterceptorBodyForEndpointParams(operationShape),
        )
        renderInterceptor(
            writer, "${operationName}EndpointParamsFinalizerInterceptor",
            implInterceptorBodyForEndpointParamsFinalizer,
        )
    }

    private fun renderInterceptor(writer: RustWriter, interceptorName: String, implBody: Writable) {
        val operationInput = symbolProvider.toSymbol(operationShape.inputShape(model))
        val interceptorName = "${operationName}EndpointParamsInterceptor"
        writer.rustTemplate(
            """
            ##[derive(Debug)]
@@ -68,39 +69,80 @@ class EndpointParamsInterceptorGenerator(
                    context: &#{InterceptorContext}<#{HttpRequest}, #{HttpResponse}>,
                    cfg: &mut #{ConfigBag},
                ) -> Result<(), #{BoxError}> {
                    #{body:W}
                }
            }
            """,
            *codegenScope,
            "body" to implBody,
        )
    }

    private fun implInterceptorBodyForEndpointParams(operationShape: OperationShape): Writable = writable {
        val operationInput = symbolProvider.toSymbol(operationShape.inputShape(model))
        rustTemplate(
            """
            let input = context.input()?;
            let _input = input
                    let _input = context.input()?;
                    let _input = _input
                        .downcast_ref::<${operationInput.name}>()
                        .ok_or("failed to downcast to ${operationInput.name}")?;
            let params_builder = cfg
                .get::<#{ParamsBuilder}>()
                .ok_or("missing endpoint params builder")?
                .clone();
            ${"" /* TODO(EndpointResolver): Call setters on `params_builder` to update its fields by using values from `_input` */}
            cfg.put(params_builder);

                    #{endpoint_prefix:W}

                    // HACK: pull the handle out of the config bag until config is implemented right
                    let handle = cfg.get::<std::sync::Arc<crate::client::Handle>>()
                        .expect("the handle is hacked into the config bag");
                    let _config = &handle.conf;

                    let params = #{Params}::builder()
                        #{param_setters}
                        .build()
                        .map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
                    cfg.put(#{EndpointResolverParams}::new(params));
                    Ok(())
                }
            }
            """,
            *codegenScope,
            "endpoint_prefix" to endpointPrefix(operationShape),
            "param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
        )
    }

    private fun paramSetters(operationShape: OperationShape, params: Parameters) = writable {
        val idx = ContextIndex.of(codegenContext.model)
        val memberParams = idx.getContextParams(operationShape).toList().sortedBy { it.first.memberName }
        val builtInParams = params.toList().filter { it.isBuiltIn }
        // first load builtins and their defaults
        builtInParams.forEach { param ->
            endpointTypesGenerator.builtInFor(param, "_config")?.also { defaultValue ->
                rust(".set_${param.name.rustName()}(#W)", defaultValue)
            }
        }

        idx.getClientContextParams(codegenContext.serviceShape).orNull()?.parameters?.forEach { (name, param) ->
            val paramName = EndpointParamsGenerator.memberName(name)
            val setterName = EndpointParamsGenerator.setterName(name)
            if (param.type == ShapeType.BOOLEAN) {
                rust(".$setterName(_config.$paramName)")
            } else {
                rust(".$setterName(_config.$paramName.clone())")
            }
        }

        idx.getStaticContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
            val setterName = EndpointParamsGenerator.setterName(name)
            val value = param.value.toWritable()
            rust(".$setterName(#W)", value)
        }

        // lastly, allow these to be overridden by members
        memberParams.forEach { (memberShape, param) ->
            val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
            rust(
                ".${EndpointParamsGenerator.setterName(param.name)}(_input.$memberName.clone())",
            )
        }
    }

    private fun Node.toWritable(): Writable {
        val node = this
        return writable {
            when (node) {
                is StringNode -> rust("Some(${node.value.dq()}.to_string())")
                is BooleanNode -> rust("Some(${node.value})")
                else -> PANIC("unsupported default value: $node")
            }
        }
    }

    private fun endpointPrefix(operationShape: OperationShape): Writable = writable {
        operationShape.getTrait(EndpointTrait::class.java).map { epTrait ->
            val endpointTraitBindings = EndpointTraitBindings(
@@ -124,25 +166,4 @@ class EndpointParamsInterceptorGenerator(
            rust("cfg.put(endpoint_prefix);")
        }
    }

    private val implInterceptorBodyForEndpointParamsFinalizer: Writable = writable {
        rustTemplate(
            """
            let _ = context;
            let params_builder = cfg
                .get::<#{ParamsBuilder}>()
                .ok_or("missing endpoint params builder")?
                .clone();
            let params = params_builder
                .build()
                .map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;
            cfg.put(
                #{EndpointResolverParams}::new(params)
            );

            Ok(())
            """,
            *codegenScope,
        )
    }
}
+4 −4
Original line number Diff line number Diff line
@@ -70,7 +70,6 @@ class ServiceRuntimePluginGenerator(
        val runtimeApi = RuntimeType.smithyRuntimeApi(rc)
        arrayOf(
            "AnonymousIdentityResolver" to runtimeApi.resolve("client::identity::AnonymousIdentityResolver"),
            "StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
            "BoxError" to runtimeApi.resolve("client::runtime_plugin::BoxError"),
            "ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"),
            "ConfigBagAccessors" to runtimeApi.resolve("client::orchestrator::ConfigBagAccessors"),
@@ -85,6 +84,7 @@ class ServiceRuntimePluginGenerator(
            "ResolveEndpoint" to http.resolve("endpoint::ResolveEndpoint"),
            "RuntimePlugin" to runtimeApi.resolve("client::runtime_plugin::RuntimePlugin"),
            "SharedEndpointResolver" to http.resolve("endpoint::SharedEndpointResolver"),
            "StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
            "TraceProbe" to runtimeApi.resolve("client::orchestrator::TraceProbe"),
        )
    }
@@ -106,6 +106,9 @@ class ServiceRuntimePluginGenerator(
                fn configure(&self, cfg: &mut #{ConfigBag}) -> Result<(), #{BoxError}> {
                    use #{ConfigBagAccessors};

                    // HACK: Put the handle into the config bag to work around config not being fully implemented yet
                    cfg.put(self.handle.clone());

                    let http_auth_schemes = #{HttpAuthSchemes}::builder()
                        #{http_auth_scheme_customizations}
                        .build();
@@ -118,9 +121,6 @@ class ServiceRuntimePluginGenerator(
                        #{SharedEndpointResolver}::from(self.handle.conf.endpoint_resolver()));
                    cfg.set_endpoint_resolver(endpoint_resolver);

                    ${"" /* TODO(EndpointResolver): Create endpoint params builder from service config */}
                    cfg.put(#{Params}::builder());

                    // TODO(RuntimePlugins): Wire up standard retry
                    cfg.set_retry_strategy(#{NeverRetryStrategy}::new());