Unverified Commit a694dd9d authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix Polly presigning in the orchestrator (#2769)

This PR fixes presigning for Amazon Polly in the orchestrator
implementation.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 9d56db1f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client" }
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-http-tower = { path = "../../../rust-runtime/aws-smithy-http-tower" }
aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api" }
aws-smithy-runtime = { path = "../../../rust-runtime/aws-smithy-runtime" }
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
aws-types = { path = "../aws-types" }
@@ -42,7 +43,6 @@ tracing = "0.1"
aws-credential-types = { path = "../aws-credential-types", features = ["test-util"] }
aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client", features = ["test-util"] }
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http", features = ["rt-tokio"] }
aws-smithy-runtime = { path = "../../../rust-runtime/aws-smithy-runtime" }
tempfile = "3.6.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async", features = ["test-util"] }
+15 −6
Original line number Diff line number Diff line
@@ -11,7 +11,9 @@ use aws_runtime::auth::sigv4::{HttpSignatureType, SigV4OperationSigningConfig};
use aws_runtime::invocation_id::InvocationIdInterceptor;
use aws_runtime::request_info::RequestInfoInterceptor;
use aws_runtime::user_agent::UserAgentInterceptor;
use aws_sigv4::http_request::SignableBody;
use aws_smithy_async::time::{SharedTimeSource, StaticTimeSource};
use aws_smithy_runtime::client::retries::strategy::NeverRetryStrategy;
use aws_smithy_runtime_api::client::interceptors::{
    disable_interceptor, BeforeSerializationInterceptorContextMut,
    BeforeTransmitInterceptorContextMut, BoxError, Interceptor, InterceptorRegistrar,
@@ -26,11 +28,15 @@ use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
#[derive(Debug)]
pub(crate) struct SigV4PresigningInterceptor {
    config: PresigningConfig,
    payload_override: SignableBody<'static>,
}

impl SigV4PresigningInterceptor {
    pub(crate) fn new(config: PresigningConfig) -> Self {
        Self { config }
    pub(crate) fn new(config: PresigningConfig, payload_override: SignableBody<'static>) -> Self {
        Self {
            config,
            payload_override,
        }
    }
}

@@ -60,8 +66,7 @@ impl Interceptor for SigV4PresigningInterceptor {
        if let Some(mut config) = cfg.get::<SigV4OperationSigningConfig>().cloned() {
            config.signing_options.expires_in = Some(self.config.expires());
            config.signing_options.signature_type = HttpSignatureType::HttpRequestQueryParams;
            config.signing_options.payload_override =
                Some(aws_sigv4::http_request::SignableBody::UnsignedPayload);
            config.signing_options.payload_override = Some(self.payload_override.clone());
            cfg.interceptor_state()
                .put::<SigV4OperationSigningConfig>(config);
            Ok(())
@@ -81,9 +86,12 @@ pub(crate) struct SigV4PresigningRuntimePlugin {
}

impl SigV4PresigningRuntimePlugin {
    pub(crate) fn new(config: PresigningConfig) -> Self {
    pub(crate) fn new(config: PresigningConfig, payload_override: SignableBody<'static>) -> Self {
        Self {
            interceptor: SharedInterceptor::new(SigV4PresigningInterceptor::new(config)),
            interceptor: SharedInterceptor::new(SigV4PresigningInterceptor::new(
                config,
                payload_override,
            )),
        }
    }
}
@@ -91,6 +99,7 @@ impl SigV4PresigningRuntimePlugin {
impl RuntimePlugin for SigV4PresigningRuntimePlugin {
    fn config(&self) -> Option<FrozenLayer> {
        let mut layer = Layer::new("Presigning");
        layer.set_retry_strategy(NeverRetryStrategy::new());
        layer.put(disable_interceptor::<InvocationIdInterceptor>("presigning"));
        layer.put(disable_interceptor::<RequestInfoInterceptor>("presigning"));
        layer.put(disable_interceptor::<UserAgentInterceptor>("presigning"));
+78 −5
Original line number Diff line number Diff line
@@ -23,17 +23,20 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSec
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.RequestSerializerGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientHttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@@ -305,14 +308,24 @@ class AwsPresignedFluentBuilderMethod(
    }

    private fun RustWriter.renderPresignedMethodBody(section: FluentClientSection.FluentBuilderImpl) {
        val presignableOp = PRESIGNABLE_OPERATIONS.getValue(section.operationShape.id)
        val operationShape = if (presignableOp.hasModelTransforms()) {
            codegenContext.model.expectShape(syntheticShapeId(section.operationShape.id), OperationShape::class.java)
        } else {
            section.operationShape
        }

        rustTemplate(
            """
            #{alternate_presigning_serializer}

            let runtime_plugins = #{Operation}::register_runtime_plugins(
                #{RuntimePlugins}::new()
                    .with_client_plugin(#{SigV4PresigningRuntimePlugin}::new(presigning_config)),
                #{RuntimePlugins}::new(),
                self.handle.clone(),
                self.config_override,
            );
                self.config_override
            )
                .with_client_plugin(#{SigV4PresigningRuntimePlugin}::new(presigning_config, #{payload_override}))
                #{alternate_presigning_serializer_registration};

            let input = self.inner.build().map_err(#{SdkError}::construction_failure)?;
            let mut context = #{Operation}::orchestrate_with_stop_point(&runtime_plugins, input, #{StopPoint}::BeforeTransmit)
@@ -332,14 +345,74 @@ class AwsPresignedFluentBuilderMethod(
            "OperationError" to section.operationErrorType,
            "RuntimePlugins" to RuntimeType.smithyRuntimeApi(runtimeConfig)
                .resolve("client::runtime_plugin::RuntimePlugins"),
            "SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors").resolve("SharedInterceptor"),
            "SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors")
                .resolve("SharedInterceptor"),
            "SigV4PresigningRuntimePlugin" to AwsRuntimeType.presigningInterceptor(runtimeConfig)
                .resolve("SigV4PresigningRuntimePlugin"),
            "StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"),
            "TypedBox" to RuntimeType.smithyTypes(runtimeConfig).resolve("type_erasure::TypedBox"),
            "USER_AGENT" to CargoDependency.Http.toType().resolve("header::USER_AGENT"),
            "alternate_presigning_serializer" to writable {
                if (presignableOp.hasModelTransforms()) {
                    val smithyTypes = RuntimeType.smithyTypes(codegenContext.runtimeConfig)
                    rustTemplate(
                        """
                        ##[derive(Debug)]
                        struct AlternatePresigningSerializerRuntimePlugin;
                        impl #{RuntimePlugin} for AlternatePresigningSerializerRuntimePlugin {
                            fn config(&self) -> Option<#{FrozenLayer}> {
                                use #{ConfigBagAccessors};
                                let mut cfg = #{Layer}::new("presigning_serializer");
                                cfg.set_request_serializer(#{AlternateSerializer});
                                Some(cfg.freeze())
                            }
                        }
                        """,
                        "AlternateSerializer" to alternateSerializer(operationShape),
                        "ConfigBagAccessors" to RuntimeType.smithyRuntimeApi(codegenContext.runtimeConfig)
                            .resolve("client::orchestrator::ConfigBagAccessors"),
                        "FrozenLayer" to smithyTypes.resolve("config_bag::FrozenLayer"),
                        "Layer" to smithyTypes.resolve("config_bag::Layer"),
                        "RuntimePlugin" to RuntimeType.runtimePlugin(codegenContext.runtimeConfig),
                    )
                }
            },
            "alternate_presigning_serializer_registration" to writable {
                if (presignableOp.hasModelTransforms()) {
                    rust(".with_operation_plugin(AlternatePresigningSerializerRuntimePlugin)")
                }
            },
            "payload_override" to writable {
                rustTemplate(
                    "#{aws_sigv4}::http_request::SignableBody::" +
                        when (presignableOp.payloadSigningType) {
                            PayloadSigningType.EMPTY -> "Bytes(b\"\")"
                            PayloadSigningType.UNSIGNED_PAYLOAD -> "UnsignedPayload"
                        },
                    "aws_sigv4" to AwsRuntimeType.awsSigv4(runtimeConfig),
                )
            },
        )
    }

    private fun alternateSerializer(transformedOperationShape: OperationShape): RuntimeType =
        transformedOperationShape.contextName(codegenContext.serviceShape).replaceFirstChar {
            it.uppercase()
        }.let { baseName ->
            "${baseName}PresigningRequestSerializer".let { name ->
                RuntimeType.forInlineFun(name, codegenContext.symbolProvider.moduleForShape(transformedOperationShape)) {
                    RequestSerializerGenerator(
                        codegenContext,
                        codegenContext.protocolImpl!!,
                        null,
                        nameOverride = name,
                    ).render(
                        this,
                        transformedOperationShape,
                    )
                }
            }
        }
}

interface PresignModelTransform {
+2 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol

/**
 * [ClientCodegenContext] contains code-generation context that is _specific_ to the [RustClientCodegenPlugin] plugin
@@ -30,6 +31,7 @@ data class ClientCodegenContext(
    // Expose the `rootDecorator`, enabling customizations to compose by referencing information from the root codegen
    // decorator
    val rootDecorator: ClientCodegenDecorator,
    val protocolImpl: Protocol? = null,
) : CodegenContext(
    model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT,
) {
+1 −0
Original line number Diff line number Diff line
@@ -108,6 +108,7 @@ class ClientCodegenVisitor(
                codegenContext,
                ClientModuleDocProvider(codegenContext, service.serviceNameOrDefault("the service")),
            ),
            protocolImpl = protocolGeneratorFactory.protocol(codegenContext),
        )

        rustCrate = RustCrate(
Loading