Unverified Commit 5761d1fe authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Make it possible to customize presigned requests (#3385)

## Motivation and Context
This enables the following code:
```rust
    let req = client
        .get_object()
        .bucket("foo")
        .key("bar")
        .customize()
        .mutate_request(|req| {
            req.set_uri(req.uri().to_string() + "&a=b")
                .expect("failed to update URI")
        })
        .presigned(static_ps_config)
        .await
```
Previously, it wasn't possible to presign a request once customization
had begun.


## Description
- Follow the same pattern as send by introducing an additional trait,
CustomizeablePresign
- change InvokeWithStopPoint to actually _return_ the error if there is
one.

## Testing
New integration test

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates
- [x] Conditionally generate the presign method

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent bcfabede
Loading
Loading
Loading
Loading
+6 −0
Original line number Original line Diff line number Diff line
@@ -29,6 +29,12 @@ references = ["aws-sdk-rust#977", "smithy-rs#3365", "smithy-rs#3373"]
meta = { "breaking" = false, "bug" = false, "tada" = false }
meta = { "breaking" = false, "bug" = false, "tada" = false }
author = "rcoh"
author = "rcoh"


[[aws-sdk-rust]]
message = "It is now possible to send customized presigned requests. You can now call `.customize().<customizations>.presigned(...).await`. Previously, only normal requests supported customization."
references = ["smithy-rs#3385", "aws-sdk-rust#1031"]
meta = { "breaking" = false, "bug" = false, "tada" = true }
author = "rcoh"

[[smithy-rs]]
[[smithy-rs]]
message = "Added impl `Display` to Enums."
message = "Added impl `Display` to Enums."
references = ["smithy-rs#3336","smithy-rs#3391"]
references = ["smithy-rs#3336","smithy-rs#3391"]
+1 −1
Original line number Original line Diff line number Diff line
@@ -56,7 +56,7 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
                    AwsPresignedFluentBuilderMethod(codegenContext),
                    AwsPresignedFluentBuilderMethod(codegenContext),
                    AwsFluentClientDocs(codegenContext),
                    AwsFluentClientDocs(codegenContext),
                ),
                ),
        ).render(rustCrate, emptyList())
        ).render(rustCrate)
        rustCrate.withModule(ClientRustModule.client) {
        rustCrate.withModule(ClientRustModule.client) {
            AwsFluentClientExtensions(codegenContext, types).render(this)
            AwsFluentClientExtensions(codegenContext, types).render(this)
        }
        }
+107 −13
Original line number Original line Diff line number Diff line
@@ -8,6 +8,7 @@ package software.amazon.smithy.rustsdk
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ServiceShape
@@ -19,8 +20,11 @@ import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.CustomizableOperationSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.InternalTraitsModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.fluentBuilderType
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.RequestSerializerGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.RequestSerializerGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
@@ -33,13 +37,16 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.cloneOperation
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf
import software.amazon.smithy.rustsdk.traits.PresignableTrait
import software.amazon.smithy.rustsdk.traits.PresignableTrait
import kotlin.streams.toList
import kotlin.streams.toList


private val presigningTypes: List<Pair<String, Any>> =
private val presigningTypes: Array<Pair<String, Any>> =
    listOf(
    arrayOf(
        "PresignedRequest" to AwsRuntimeType.presigning().resolve("PresignedRequest"),
        "PresignedRequest" to AwsRuntimeType.presigning().resolve("PresignedRequest"),
        "PresigningConfig" to AwsRuntimeType.presigning().resolve("PresigningConfig"),
        "PresigningConfig" to AwsRuntimeType.presigning().resolve("PresigningConfig"),
    )
    )
@@ -94,6 +101,12 @@ class AwsPresigningDecorator internal constructor(
    override val name: String = "AwsPresigning"
    override val name: String = "AwsPresigning"
    override val order: Byte = ORDER
    override val order: Byte = ORDER


    private val codegenScope =
        arrayOf(
            *presigningTypes,
            *preludeScope,
        )

    /**
    /**
     * Adds presignable trait to known presignable operations and creates synthetic presignable shapes for codegen
     * Adds presignable trait to known presignable operations and creates synthetic presignable shapes for codegen
     */
     */
@@ -128,6 +141,43 @@ class AwsPresigningDecorator internal constructor(
            }
            }
        }.build()
        }.build()
    }
    }

    private fun anyPresignedShapes(ctx: ClientCodegenContext) =
        TopDownIndex.of(ctx.model).getContainedOperations(ctx.serviceShape)
            .any { presignableOperations.containsKey(it.id) }

    override fun extraSections(codegenContext: ClientCodegenContext): List<AdHocCustomization> =
        anyPresignedShapes(codegenContext).thenSingletonListOf {
            adhocCustomization<CustomizableOperationSection.CustomizableOperationImpl> {
                rustTemplate(
                    """
                    /// Sends the request and returns the response.
                    ##[allow(unused_mut)]
                    pub async fn presigned(mut self, presigning_config: #{PresigningConfig}) -> #{Result}<#{PresignedRequest}, crate::error::SdkError<E>> where
                        E: std::error::Error + #{Send} + #{Sync} + 'static,
                        B: #{CustomizablePresigned}<E>
                    {
                        self.execute(move |sender, conf|sender.presign(conf, presigning_config)).await
                    }
                    """,
                    *codegenScope,
                    "CustomizablePresigned" to CustomizablePresigned,
                )
            }
        }

    private val CustomizablePresigned =
        RuntimeType.forInlineFun("CustomizablePresigned", InternalTraitsModule) {
            rustTemplate(
                """
                pub trait CustomizablePresigned<E>: #{Send} + #{Sync} {
                    fn presign(self, config_override: crate::config::Builder, presigning_config: #{PresigningConfig}) -> BoxFuture<SendResult<#{PresignedRequest}, E>>;
                }

                """,
                *codegenScope,
            )
        }
}
}


class AwsPresignedFluentBuilderMethod(
class AwsPresignedFluentBuilderMethod(
@@ -135,14 +185,12 @@ class AwsPresignedFluentBuilderMethod(
) : FluentClientCustomization() {
) : FluentClientCustomization() {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope =
    private val codegenScope =
        (
            presigningTypes +
        arrayOf(
        arrayOf(
                    *RuntimeType.preludeScope,
            *presigningTypes,
            *preludeScope,
            "Error" to AwsRuntimeType.presigning().resolve("config::Error"),
            "Error" to AwsRuntimeType.presigning().resolve("config::Error"),
            "SdkError" to RuntimeType.sdkError(runtimeConfig),
            "SdkError" to RuntimeType.sdkError(runtimeConfig),
        )
        )
        ).toTypedArray()


    override fun section(section: FluentClientSection): Writable =
    override fun section(section: FluentClientSection): Writable =
        writable {
        writable {
@@ -159,18 +207,61 @@ class AwsPresignedFluentBuilderMethod(
                    *codegenScope,
                    *codegenScope,
                    "OpError" to section.operationErrorType,
                    "OpError" to section.operationErrorType,
                    "RawResponseType" to
                    "RawResponseType" to
                        RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::orchestrator::HttpResponse"),
                        RuntimeType.smithyRuntimeApiClient(runtimeConfig)
                            .resolve("client::orchestrator::HttpResponse"),
                ) {
                ) {
                    renderPresignedMethodBody(section)
                    renderPresignedMethodBody(section)
                    val builderName = section.operationShape.fluentBuilderType(codegenContext.symbolProvider).name
                    addDependency(implementPresignedTrait(section, builderName).dependency!!)
                }
                }
            }
            }
        }
        }


    private fun implementPresignedTrait(
        section: FluentClientSection.FluentBuilderImpl,
        builderName: String,
    ): RuntimeType {
        return RuntimeType.forInlineFun(
            "TraitImplementation",
            codegenContext.symbolProvider.moduleForBuilder(section.operationShape),
        ) {
            rustTemplate(
                """
                impl
                    crate::client::customize::internal::CustomizablePresigned<
                        #{OperationError},
                    > for $builderName
                {
                    fn presign(
                        self,
                        config_override: crate::config::Builder,
                        presigning_config: #{PresigningConfig}
                    ) -> crate::client::customize::internal::BoxFuture<
                        crate::client::customize::internal::SendResult<
                            #{PresignedRequest},
                            #{OperationError},
                        >,
                    > {
                        #{Box}::pin(async move { self.config_override(config_override).presigned(presigning_config).await })
                    }
                }
                """,
                *preludeScope,
                *presigningTypes,
                "OperationError" to section.operationErrorType,
                "SdkError" to RuntimeType.sdkError(runtimeConfig),
            )
        }
    }

    private fun RustWriter.renderPresignedMethodBody(section: FluentClientSection.FluentBuilderImpl) {
    private fun RustWriter.renderPresignedMethodBody(section: FluentClientSection.FluentBuilderImpl) {
        val presignableOp = PRESIGNABLE_OPERATIONS.getValue(section.operationShape.id)
        val presignableOp = PRESIGNABLE_OPERATIONS.getValue(section.operationShape.id)
        val operationShape =
        val operationShape =
            if (presignableOp.hasModelTransforms()) {
            if (presignableOp.hasModelTransforms()) {
                codegenContext.model.expectShape(syntheticShapeId(section.operationShape.id), OperationShape::class.java)
                codegenContext.model.expectShape(
                    syntheticShapeId(section.operationShape.id),
                    OperationShape::class.java,
                )
            } else {
            } else {
                section.operationShape
                section.operationShape
            }
            }
@@ -262,7 +353,10 @@ class AwsPresignedFluentBuilderMethod(
            it.uppercase()
            it.uppercase()
        }.let { baseName ->
        }.let { baseName ->
            "${baseName}PresigningRequestSerializer".let { name ->
            "${baseName}PresigningRequestSerializer".let { name ->
                RuntimeType.forInlineFun(name, codegenContext.symbolProvider.moduleForShape(transformedOperationShape)) {
                RuntimeType.forInlineFun(
                    name,
                    codegenContext.symbolProvider.moduleForShape(transformedOperationShape),
                ) {
                    RequestSerializerGenerator(
                    RequestSerializerGenerator(
                        codegenContext,
                        codegenContext,
                        codegenContext.protocolImpl!!,
                        codegenContext.protocolImpl!!,
+54 −0
Original line number Original line Diff line number Diff line
@@ -58,3 +58,57 @@ async fn test_presigning() {
    );
    );
    assert_eq!(presigned.headers().count(), 0);
    assert_eq!(presigned.headers().count(), 0);
}
}

#[tokio::test]
async fn test_presigning_customized() {
    let config = Config::builder()
        .credentials_provider(Credentials::for_tests_with_session_token())
        .region(Region::new("us-east-1"))
        .build();
    let client = polly::Client::from_conf(config);

    let presigned = client
        .synthesize_speech()
        .output_format(OutputFormat::Mp3)
        .text("hello, world")
        .voice_id(VoiceId::Joanna)
        .customize()
        .config_override(Config::builder().region(Region::new("us-west-1")))
        .mutate_request(|req| req.set_uri(req.uri().to_string() + "&test").expect("valid"))
        .presigned(
            PresigningConfig::builder()
                .start_time(SystemTime::UNIX_EPOCH + Duration::from_secs(1234567891))
                .expires_in(Duration::from_secs(30))
                .build()
                .unwrap(),
        )
        .await
        .expect("success");

    let uri = presigned.uri().parse::<http::Uri>().unwrap();
    let pq = uri.path_and_query().unwrap();
    let path = pq.path();
    let query = pq.query().unwrap();
    let mut query_params: Vec<&str> = query.split('&').collect();
    query_params.sort();

    assert_eq!("GET", presigned.method());
    assert_eq!("/v1/speech", path);
    assert_eq!(
        &[
            "OutputFormat=mp3",
            "Text=hello%2C%20world",
            "VoiceId=Joanna",
            "X-Amz-Algorithm=AWS4-HMAC-SHA256",
            "X-Amz-Credential=ANOTREAL%2F20090213%2Fus-west-1%2Fpolly%2Faws4_request",
            "X-Amz-Date=20090213T233131Z",
            "X-Amz-Expires=30",
            "X-Amz-Security-Token=notarealsessiontoken",
            "X-Amz-Signature=7cc39d2dfa3b8057f901b2827522790b48c6162571ed7e09c9725178c1cdd1fb",
            "X-Amz-SignedHeaders=host",
            "test",
        ][..],
        &query_params
    );
    assert_eq!(presigned.headers().count(), 0);
}
+49 −10
Original line number Original line Diff line number Diff line
@@ -5,9 +5,11 @@


use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::presigning::PresigningConfig;
use aws_sdk_s3::{Client, Config};
use aws_sdk_s3::{Client, Config};
use aws_smithy_runtime::client::http::test_util::capture_request;
use aws_smithy_runtime::client::http::test_util::capture_request;
use http::HeaderValue;
use http::HeaderValue;
use std::time::{Duration, SystemTime};


#[tokio::test]
#[tokio::test]
async fn test_s3_ops_are_customizable() {
async fn test_s3_ops_are_customizable() {
@@ -24,14 +26,16 @@ async fn test_s3_ops_are_customizable() {


    // The response from the fake connection won't return the expected XML but we don't care about
    // The response from the fake connection won't return the expected XML but we don't care about
    // that error in this test
    // that error in this test
    let _ = client
    let _ = assert_send(
        client
            .list_buckets()
            .list_buckets()
            .customize()
            .customize()
            .mutate_request(|req| {
            .mutate_request(|req| {
                req.headers_mut()
                req.headers_mut()
                    .append("test-header", HeaderValue::from_static("test-value"));
                    .append("test-header", HeaderValue::from_static("test-value"));
            })
            })
        .send()
            .send(),
    )
    .await
    .await
    .expect_err("this will fail due to not receiving a proper XML response.");
    .expect_err("this will fail due to not receiving a proper XML response.");


@@ -44,3 +48,38 @@ async fn test_s3_ops_are_customizable() {


    assert_eq!("test-value", test_header);
    assert_eq!("test-value", test_header);
}
}

#[tokio::test]
async fn customized_presigning() {
    let creds = Credentials::for_tests_with_session_token();
    let config = Config::builder()
        .credentials_provider(creds)
        .region(Region::new("us-east-1"))
        .build();
    let client = Client::from_conf(config);
    let static_ps_config = PresigningConfig::builder()
        .start_time(SystemTime::UNIX_EPOCH + Duration::from_secs(1234567891))
        .expires_in(Duration::from_secs(30))
        .build()
        .unwrap();
    let req = assert_send(
        client
            .get_object()
            .bucket("foo")
            .key("bar")
            .customize()
            .mutate_request(|req| {
                req.set_uri(req.uri().to_string() + "&a=b")
                    .expect("failed to update URI")
            })
            .presigned(static_ps_config),
    )
    .await
    .unwrap();
    let expect = "https://foo.s3.us-east-1.amazonaws.com/bar?x-id=GetObject&a=b&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ANOTREAL%2F20090213%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20090213T233131Z&X-Amz-Expires=30&X-Amz-SignedHeaders=host&X-Amz-Signature=2e1a459c206932ce53beb07028c711cf70f3a61dc876c6f9ce0aed5823f60234&X-Amz-Security-Token=notarealsessiontoken";
    assert_eq!(req.uri(), expect);
}

fn assert_send<T: Send>(t: T) -> T {
    t
}
Loading