Unverified Commit 1cecc3ba authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

Fix request compression (#3820)



## Description
<!--- Describe your changes in detail -->
This PR includes several fixes for request compression:
- `aws_smithy_compression::body::compress::CompressedBody` will no
longer incorrectly return the inner body's `SizeHint`, returning
`SizeHint::default()` instead.
- Fixed a bug where compressed payloads would have an incorrect content
length, causing those requests to hang.
- Compress in-memory request payloads instead of the previous lazy
approach.

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Tests are included. Additionally, I ran a reproducer for a user issues
to ensure that their use case works:

```
#[tokio::test]
async fn use_case_reproducer() {
    tracing_subscriber::fmt::init();
    let shared_config = aws_config::from_env()
        .region(aws_sdk_cloudwatch::config::Region::new("us-west-2"))
        .load()
        .await;
    let service_config = aws_sdk_cloudwatch::config::Config::from(&shared_config)
        .to_builder()
        .request_min_compression_size_bytes(1)
        .build();
    let client = Client::from_conf(service_config);

    tracing::info!("sending metrics...");
    client
        .put_metric_data()
        .namespace("CloudWatchTestMetricsBrivinc")
        .metric_data(
            aws_sdk_cloudwatch::types::MetricDatum::builder()
                .metric_name(format!("MyMetricNameIsALittleLong"))
                .value(0.0)
                .build(),
        )
        .send()
        .await
        .unwrap();
}
``` 

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarysaito1001 <awsaito@amazon.com>
parent ec226c02
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
---
applies_to: ["client", "aws-sdk-rust"]
authors: ["Velfi"]
references: ["smithy-rs#3820"]
breaking: false
new_feature: false
bug_fix: true
---

Fixed a bug with the content length of compressed payloads that caused such requests to hang.
+5 −5
Original line number Diff line number Diff line
@@ -149,7 +149,7 @@ dependencies = [

[[package]]
name = "aws-runtime"
version = "1.4.2"
version = "1.4.3"
dependencies = [
 "arbitrary",
 "aws-credential-types",
@@ -195,7 +195,7 @@ version = "0.60.3"

[[package]]
name = "aws-sigv4"
version = "1.2.3"
version = "1.2.4"
dependencies = [
 "aws-credential-types",
 "aws-smithy-eventstream",
@@ -259,7 +259,7 @@ dependencies = [

[[package]]
name = "aws-smithy-eventstream"
version = "0.60.4"
version = "0.60.5"
dependencies = [
 "aws-smithy-types",
 "bytes",
@@ -268,7 +268,7 @@ dependencies = [

[[package]]
name = "aws-smithy-http"
version = "0.60.10"
version = "0.60.11"
dependencies = [
 "aws-smithy-runtime-api",
 "aws-smithy-types",
@@ -343,7 +343,7 @@ dependencies = [

[[package]]
name = "aws-smithy-types"
version = "1.2.5"
version = "1.2.6"
dependencies = [
 "base64-simd",
 "bytes",
+5 −5
Original line number Diff line number Diff line
@@ -89,7 +89,7 @@ dependencies = [

[[package]]
name = "aws-runtime"
version = "1.4.2"
version = "1.4.3"
dependencies = [
 "aws-credential-types",
 "aws-sigv4",
@@ -173,7 +173,7 @@ dependencies = [

[[package]]
name = "aws-sigv4"
version = "1.2.3"
version = "1.2.4"
dependencies = [
 "aws-credential-types",
 "aws-smithy-http",
@@ -203,7 +203,7 @@ dependencies = [

[[package]]
name = "aws-smithy-http"
version = "0.60.10"
version = "0.60.11"
dependencies = [
 "aws-smithy-runtime-api",
 "aws-smithy-types",
@@ -298,7 +298,7 @@ dependencies = [

[[package]]
name = "aws-smithy-types"
version = "1.2.5"
version = "1.2.6"
dependencies = [
 "base64-simd",
 "bytes",
@@ -319,7 +319,7 @@ dependencies = [

[[package]]
name = "aws-smithy-xml"
version = "0.60.8"
version = "0.60.9"
dependencies = [
 "xmlparser",
]
+1 −1
Original line number Diff line number Diff line
[package]
name = "aws-config"
version = "1.5.6"
version = "1.5.7"
authors = [
    "AWS Rust SDK Team <aws-sdk-rust@amazon.com>",
    "Russell Cohen <rcoh@amazon.com>",
+201 −64
Original line number Diff line number Diff line
@@ -15,11 +15,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.integrationTest

class HttpRequestCompressionDecoratorTest {
    companion object {
        // Can't use the dollar sign in a multiline string with doing it like this.
        private const val PREFIX = "\$version: \"2\""
        val model =
            """
            $PREFIX
            namespace test

            use aws.api#service
@@ -100,7 +97,7 @@ class HttpRequestCompressionDecoratorTest {

            @output
            structure SomeIncompressibleOutput {}
            """.asSmithyModel()
            """.asSmithyModel(smithyVersion = "2.0")
    }

    @Test
@@ -147,14 +144,22 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.not_a_compressible_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .not_a_compressible_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set.
                        assert_eq!(None, request.headers().get("content-encoding"));
                        assert_eq!(None, request.headers().get(#{http}::header::CONTENT_ENCODING));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -168,14 +173,25 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_ne!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -190,14 +206,25 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is set to "gzip"
                        assert_eq!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_eq!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was compressed
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice())
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -212,14 +239,25 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_ne!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -233,14 +271,25 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is set to "gzip"
                        assert_eq!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_eq!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was compressed
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice())
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -254,14 +303,25 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_ne!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
@@ -277,7 +337,7 @@ class HttpRequestCompressionDecoratorTest {
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                            // ByteStreams created from a file are streaming and have a known size
                        // ByteStreams created from a file are streaming
                        let mut file = #{tempfile}::NamedTempFile::new().unwrap();
                        use std::io::Write;
                        file.write_all(UNCOMPRESSED_INPUT).unwrap();
@@ -285,30 +345,107 @@ class HttpRequestCompressionDecoratorTest {
                        let body = ByteStream::read_from()
                            .path(file.path())
                            .buffer_size(1024)
                                .length(#{Length}::Exact(UNCOMPRESSED_INPUT.len() as u64))
                            .length(::aws_smithy_types::byte_stream::Length::Exact(
                                UNCOMPRESSED_INPUT.len() as u64,
                            ))
                            .build()
                            .await
                            .unwrap();
                            let _ = client
                                .some_streaming_operation()
                                .body(body)
                                .send()
                                .await;
                        let _ = client.some_streaming_operation().body(body).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is set to "gzip"
                            assert_eq!(Some("gzip"), request.headers().get("content-encoding"));
                        assert_eq!(
                            Some("gzip"),
                            request.headers().get(#{http}::header::CONTENT_ENCODING)
                        );

                            let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        let compressed_body = ByteStream::from(request.into_body())
                            .collect()
                            .await
                            .unwrap()
                            .to_vec();
                        // Assert input body is different from uncompressed input
                            assert_ne!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                        assert_ne!(
                            UNCOMPRESSED_INPUT,
                            compressed_body.as_slice(),
                            "input was not compressed"
                        );
                        // Assert input body was compressed
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice());
                    }

                    ##[#{tokio}::test]
                    async fn test_compressed_content_length() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            .request_min_compression_size_bytes(0)
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client
                            .some_operation()
                            .body(Blob::new(UNCOMPRESSED_INPUT))
                            .send()
                            .await;
                        let request = rx.expect_request();
                        // Check that the content-length header is set correctly.
                        if let Some(content_length) = request
                            .headers()
                            .get(#{http}::header::CONTENT_LENGTH)
                            .and_then(|len| len.parse::<usize>().ok())
                        {
                            assert_ne!(
                                content_length, UNCOMPRESSED_INPUT.len(),
                                "`content-length` of in-memory payload was incorrectly set to the length of the uncompressed input but should have been set to the length of the compressed payload"
                            );
                            assert_eq!(COMPRESSED_OUTPUT.len(), content_length);
                        }

                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            .request_min_compression_size_bytes(0)
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        // ByteStreams created from a file are streaming
                        let mut file = #{tempfile}::NamedTempFile::new().unwrap();
                        use std::io::Write;
                        file.write_all(UNCOMPRESSED_INPUT).unwrap();

                        let body = ByteStream::read_from()
                            .path(file.path())
                            .buffer_size(1024)
                            .length(::aws_smithy_types::byte_stream::Length::Exact(
                                UNCOMPRESSED_INPUT.len() as u64,
                            ))
                            .build()
                            .await
                            .unwrap();
                        let _ = client.some_streaming_operation().body(body).send().await;
                        let request = rx.expect_request();

                        assert!(
                            request
                                .headers()
                                .get(#{http}::header::CONTENT_LENGTH)
                                .is_none(),
                            "expected that no content length header is set because the request is streaming."
                        );
                    }
                    """,
                    *preludeScope,
                    "ByteStream" to RuntimeType.smithyTypes(rc).resolve("byte_stream::ByteStream"),
                    "Blob" to RuntimeType.smithyTypes(rc).resolve("Blob"),
                    "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
                    "http" to CargoDependency.Http.toType(),
                    "tokio" to CargoDependency.Tokio.toType(),
                    "capture_request" to RuntimeType.captureRequest(rc),
                    "pretty_assertions" to CargoDependency.PrettyAssertions.toType(),
Loading