Unverified Commit 1de38029 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix `@httpChecksumRequired` and idempotency tokens in the orchestrator (#2817)

This PR fixes the Smithy `@httpChecksumRequired` trait and idempotency
token auto-fill in the orchestrator implementation.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent ee5aadc2
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -8,10 +8,15 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.HttpChecksumRequiredTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
@@ -32,6 +37,22 @@ class HttpChecksumRequiredGenerator(
            throw CodegenException("HttpChecksum required cannot be applied to a streaming shape")
        }
        return when (section) {
            is OperationSection.AdditionalRuntimePlugins -> writable {
                section.addOperationRuntimePlugin(this) {
                    rustTemplate(
                        "#{HttpChecksumRequiredRuntimePlugin}",
                        "HttpChecksumRequiredRuntimePlugin" to
                            InlineDependency.forRustFile(
                                RustModule.pubCrate("client_http_checksum_required", parent = ClientRustModule.root),
                                "/inlineable/src/client_http_checksum_required.rs",
                                CargoDependency.smithyRuntimeApi(codegenContext.runtimeConfig),
                                CargoDependency.smithyTypes(codegenContext.runtimeConfig),
                                CargoDependency.Http,
                                CargoDependency.Md5,
                            ).toType().resolve("HttpChecksumRequiredRuntimePlugin"),
                    )
                }
            }
            is OperationSection.MutateRequest -> writable {
                rustTemplate(
                    """
+36 −3
Original line number Diff line number Diff line
@@ -7,27 +7,60 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.IdempotencyTokenTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.toType
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape

class IdempotencyTokenGenerator(codegenContext: CodegenContext, operationShape: OperationShape) :
    OperationCustomization() {
class IdempotencyTokenGenerator(
    codegenContext: CodegenContext,
    operationShape: OperationShape,
) : OperationCustomization() {
    private val model = codegenContext.model
    private val runtimeConfig = codegenContext.runtimeConfig
    private val symbolProvider = codegenContext.symbolProvider
    private val idempotencyTokenMember = operationShape.inputShape(model).findMemberWithTrait<IdempotencyTokenTrait>(model)
    private val inputShape = operationShape.inputShape(model)
    private val idempotencyTokenMember = inputShape.findMemberWithTrait<IdempotencyTokenTrait>(model)

    override fun section(section: OperationSection): Writable {
        if (idempotencyTokenMember == null) {
            return emptySection
        }
        val memberName = symbolProvider.toMemberName(idempotencyTokenMember)
        return when (section) {
            is OperationSection.AdditionalRuntimePlugins -> writable {
                section.addOperationRuntimePlugin(this) {
                    rustTemplate(
                        """
                        #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| {
                            let input: &mut #{Input} = input.downcast_mut().expect("correct type");
                            if input.$memberName.is_none() {
                                input.$memberName = #{Some}(token_provider.make_idempotency_token());
                            }
                        })
                        """,
                        *preludeScope,
                        "Input" to symbolProvider.toSymbol(inputShape),
                        "IdempotencyTokenRuntimePlugin" to
                            InlineDependency.forRustFile(
                                RustModule.pubCrate("client_idempotency_token", parent = ClientRustModule.root),
                                "/inlineable/src/client_idempotency_token.rs",
                                CargoDependency.smithyRuntimeApi(runtimeConfig),
                                CargoDependency.smithyTypes(runtimeConfig),
                            ).toType().resolve("IdempotencyTokenRuntimePlugin"),
                    )
                }
            }
            is OperationSection.MutateInput -> writable {
                rustTemplate(
                    """
+16 −14
Original line number Diff line number Diff line
@@ -18,22 +18,24 @@ default = ["gated-tests"]


[dependencies]
"bytes" = "1"
"http" = "0.2.1"
"aws-smithy-types" = { path = "../aws-smithy-types" }
"aws-smithy-json" = { path = "../aws-smithy-json" }
"aws-smithy-xml" = { path = "../aws-smithy-xml" }
"aws-smithy-http-server" = { path = "../aws-smithy-http-server" }
"fastrand" = "1"
"futures-util" = "0.3"
"pin-project-lite" = "0.2"
"tower" = { version = "0.4.11", default-features = false }
"async-trait" = "0.1"
percent-encoding = "2.2.0"
async-trait = "0.1"
aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-http-server = { path = "../aws-smithy-http-server" }
aws-smithy-json = { path = "../aws-smithy-json" }
aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api" }
aws-smithy-types = { path = "../aws-smithy-types" }
aws-smithy-xml = { path = "../aws-smithy-xml" }
bytes = "1"
fastrand = "1"
futures-util = "0.3"
http = "0.2.1"
md-5 = "0.10.0"
once_cell = "1.16.0"
percent-encoding = "2.2.0"
pin-project-lite = "0.2"
regex = "1.5.5"
"url" = "2.2.2"
aws-smithy-http = { path = "../aws-smithy-http" }
tower = { version = "0.4.11", default-features = false }
url = "2.2.2"

[dev-dependencies]
proptest = "1"
+48 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::{
    Interceptor, InterceptorRegistrar, SharedInterceptor,
};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::base64;
use aws_smithy_types::config_bag::ConfigBag;
use http::header::HeaderName;

#[derive(Debug)]
pub(crate) struct HttpChecksumRequiredRuntimePlugin;

impl RuntimePlugin for HttpChecksumRequiredRuntimePlugin {
    fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
        interceptors.register(SharedInterceptor::new(HttpChecksumRequiredInterceptor));
    }
}

#[derive(Debug)]
struct HttpChecksumRequiredInterceptor;

impl Interceptor for HttpChecksumRequiredInterceptor {
    fn modify_before_signing(
        &self,
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let request = context.request_mut();
        let body_bytes = request
            .body()
            .bytes()
            .expect("checksum can only be computed for non-streaming operations");
        let checksum = <md5::Md5 as md5::Digest>::digest(body_bytes);
        request.headers_mut().insert(
            HeaderName::from_static("content-md5"),
            base64::encode(&checksum[..])
                .parse()
                .expect("checksum is a valid header value"),
        );
        Ok(())
    }
}
+66 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use crate::idempotency_token::IdempotencyTokenProvider;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
    BeforeSerializationInterceptorContextMut, Input,
};
use aws_smithy_runtime_api::client::interceptors::{
    Interceptor, InterceptorRegistrar, SharedInterceptor,
};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_types::config_bag::ConfigBag;
use std::fmt;

#[derive(Debug)]
pub(crate) struct IdempotencyTokenRuntimePlugin {
    interceptor: SharedInterceptor,
}

impl IdempotencyTokenRuntimePlugin {
    pub(crate) fn new<S>(set_token: S) -> Self
    where
        S: Fn(IdempotencyTokenProvider, &mut Input) + Send + Sync + 'static,
    {
        Self {
            interceptor: SharedInterceptor::new(IdempotencyTokenInterceptor { set_token }),
        }
    }
}

impl RuntimePlugin for IdempotencyTokenRuntimePlugin {
    fn interceptors(&self, interceptors: &mut InterceptorRegistrar) {
        interceptors.register(self.interceptor.clone());
    }
}

struct IdempotencyTokenInterceptor<S> {
    set_token: S,
}

impl<S> fmt::Debug for IdempotencyTokenInterceptor<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IdempotencyTokenInterceptor").finish()
    }
}

impl<S> Interceptor for IdempotencyTokenInterceptor<S>
where
    S: Fn(IdempotencyTokenProvider, &mut Input) + Send + Sync,
{
    fn modify_before_serialization(
        &self,
        context: &mut BeforeSerializationInterceptorContextMut<'_>,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let token_provider = cfg
            .load::<IdempotencyTokenProvider>()
            .expect("the idempotency provider must be set")
            .clone();
        (self.set_token)(token_provider, context.input_mut());
        Ok(())
    }
}
Loading