Unverified Commit a4301fa4 authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

Request Compression PR 3/3 - Codegen (#3638)



PR 3 of 3. Most of the compression tests happen here since we don't have
a request-compression-supporting model in our set of "smoke test"
service models.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarysaito1001 <awsaito@amazon.com>
Co-authored-by: default avatarAaron Todd <aajtodd@users.noreply.github.com>
parent 0b6b121c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ val DECORATORS: List<ClientCodegenDecorator> =
            RemoveDefaultsDecorator(),
            TokenProvidersDecorator(),
            ServiceEnvConfigDecorator(),
            HttpRequestCompressionDecorator(),
        ),
        // Service specific decorators
        ApiGatewayDecorator().onlyApplyTo("com.amazonaws.apigateway#BackplaneControlService"),
+137 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.traits.RequestCompressionTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class HttpRequestCompressionDecorator : ClientCodegenDecorator {
    override val name: String = "HttpRequestCompression"
    override val order: Byte = 0

    private fun usesRequestCompression(codegenContext: ClientCodegenContext): Boolean =
        codegenContext.model.isTraitApplied(RequestCompressionTrait::class.java)

    override fun configCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
    ): List<ConfigCustomization> {
        return baseCustomizations +
            usesRequestCompression(codegenContext).thenSingletonListOf {
                HttpRequestCompressionConfigCustomization(codegenContext)
            }
    }

    override fun extraSections(codegenContext: ClientCodegenContext): List<AdHocCustomization> {
        return usesRequestCompression(codegenContext).thenSingletonListOf {
            adhocCustomization<SdkConfigSection.CopySdkConfigToClientConfig> { section ->
                rust(
                    """
                    ${section.serviceConfigBuilder} = ${section.serviceConfigBuilder}
                        .disable_request_compression(${section.sdkConfig}.disable_request_compression());
                    ${section.serviceConfigBuilder} = ${section.serviceConfigBuilder}
                        .request_min_compression_size_bytes(${section.sdkConfig}.request_min_compression_size_bytes());
                    """,
                )
            }
        }
    }
}

class HttpRequestCompressionConfigCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope =
        arrayOf(
            "DisableRequestCompression" to RuntimeType.clientRequestCompression(runtimeConfig).resolve("DisableRequestCompression"),
            "RequestMinCompressionSizeBytes" to RuntimeType.clientRequestCompression(runtimeConfig).resolve("RequestMinCompressionSizeBytes"),
            "Storable" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::Storable"),
            "StoreReplace" to RuntimeType.smithyTypes(runtimeConfig).resolve("config_bag::StoreReplace"),
            *preludeScope,
        )

    override fun section(section: ServiceConfig) =
        writable {
            when (section) {
                ServiceConfig.ConfigImpl -> {
                    rustTemplate(
                        """
                        /// Returns the `disable request compression` setting, if it was provided.
                        pub fn disable_request_compression(&self) -> #{Option}<bool> {
                            self.config.load::<#{DisableRequestCompression}>().map(|it| it.0)
                        }

                        /// Returns the `request minimum compression size in bytes`, if it was provided.
                        pub fn request_min_compression_size_bytes(&self) -> #{Option}<u32> {
                            self.config.load::<#{RequestMinCompressionSizeBytes}>().map(|it| it.0)
                        }
                        """,
                        *codegenScope,
                    )
                }

                ServiceConfig.BuilderImpl -> {
                    rustTemplate(
                        """
                        /// Sets the `disable request compression` used when making requests.
                        pub fn disable_request_compression(mut self, disable_request_compression: impl #{Into}<#{Option}<bool>>) -> Self {
                            self.set_disable_request_compression(disable_request_compression.into());
                            self
                        }

                        /// Sets the `request minimum compression size in bytes` used when making requests.
                        pub fn request_min_compression_size_bytes(mut self, request_min_compression_size_bytes: impl #{Into}<#{Option}<u32>>) -> Self {
                            self.set_request_min_compression_size_bytes(request_min_compression_size_bytes.into());
                            self
                        }
                        """,
                        *codegenScope,
                    )

                    rustTemplate(
                        """
                        /// Sets the `disable request compression` used when making requests.
                        pub fn set_disable_request_compression(&mut self, disable_request_compression: #{Option}<bool>) -> &mut Self {
                            self.config.store_or_unset::<#{DisableRequestCompression}>(disable_request_compression.map(Into::into));
                            self
                        }

                        /// Sets the `request minimum compression size in bytes` used when making requests.
                        pub fn set_request_min_compression_size_bytes(&mut self, request_min_compression_size_bytes: #{Option}<u32>) -> &mut Self {
                            self.config.store_or_unset::<#{RequestMinCompressionSizeBytes}>(request_min_compression_size_bytes.map(Into::into));
                            self
                        }
                        """,
                        *codegenScope,
                    )
                }

                is ServiceConfig.BuilderFromConfigBag -> {
                    rustTemplate(
                        """
                        ${section.builder}.set_disable_request_compression(
                            ${section.configBag}.load::<#{DisableRequestCompression}>().cloned().map(|it| it.0));
                        ${section.builder}.set_request_min_compression_size_bytes(
                            ${section.configBag}.load::<#{RequestMinCompressionSizeBytes}>().cloned().map(|it| it.0));
                        """,
                        *codegenScope,
                    )
                }

                else -> emptySection
            }
        }
}
+322 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest

class HttpRequestCompressionDecoratorTest {
    companion object {
        // Can't use the dollar sign in a multiline string with doing it like this.
        private const val PREFIX = "\$version: \"2\""
        val model =
            """
            $PREFIX
            namespace test

            use aws.api#service
            use aws.auth#sigv4
            use aws.protocols#restJson1
            use smithy.rules#endpointRuleSet

            @service(sdkId: "dontcare")
            @restJson1
            @sigv4(name: "dontcare")
            @auth([sigv4])
            @endpointRuleSet({
                "version": "1.0",
                "rules": [{ "type": "endpoint", "conditions": [], "endpoint": { "url": "https://example.com" } }],
                "parameters": {
                    "Region": { "required": false, "type": "String", "builtIn": "AWS::Region" },
                }
            })
            service TestService {
                version: "2023-01-01",
                operations: [SomeOperation, SomeStreamingOperation, NotACompressibleOperation]
            }

            @streaming
            blob StreamingBlob

            blob NonStreamingBlob

            @http(uri: "/SomeOperation", method: "POST")
            @optionalAuth
            @requestCompression(encodings: ["gzip"])
            operation SomeOperation {
                input: SomeInput,
                output: SomeOutput
            }

            @input
            structure SomeInput {
                @httpPayload
                @required
                body: NonStreamingBlob
            }

            @output
            structure SomeOutput {}

            @http(uri: "/SomeStreamingOperation", method: "POST")
            @optionalAuth
            @requestCompression(encodings: ["gzip"])
            operation SomeStreamingOperation {
                input: SomeStreamingInput,
                output: SomeStreamingOutput
            }

            @input
            structure SomeStreamingInput {
                @httpPayload
                @required
                body: StreamingBlob
            }

            @output
            structure SomeStreamingOutput {}

            @http(uri: "/NotACompressibleOperation", method: "PUT")
            @optionalAuth
            operation NotACompressibleOperation {
                input: SomeIncompressibleInput,
                output: SomeIncompressibleOutput
            }

            @input
            structure SomeIncompressibleInput {
                @httpPayload
                @required
                body: NonStreamingBlob
            }

            @output
            structure SomeIncompressibleOutput {}
            """.asSmithyModel()
    }

    @Test
    fun smokeTestSdkCodegen() {
        awsSdkIntegrationTest(model) { _, _ ->
            // it should compile
        }
    }

    @Test
    fun requestCompressionWorks() {
        awsSdkIntegrationTest(model) { context, rustCrate ->
            val rc = context.runtimeConfig
            val moduleName = context.moduleUseName()
            rustCrate.integrationTest("request_compression") {
                rustTemplate(
                    """
                    ##![cfg(feature = "test-util")]

                    use #{ByteStream};
                    use #{Blob};
                    use #{Region};
                    use #{pretty_assertions}::{assert_eq, assert_ne};

                    const UNCOMPRESSED_INPUT: &[u8] = b"Action=PutMetricData&Version=2010-08-01&Namespace=Namespace&MetricData.member.1.MetricName=metric&MetricData.member.1.Unit=Bytes&MetricData.member.1.Value=128";
                    // This may break if we ever change the default compression level.
                    const COMPRESSED_OUTPUT: &[u8] = &[
                        31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 1, 115, 0, 140, 255, 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 109,
                        139, 49, 14, 128, 32, 16, 4, 127, 67, 39, 1, 43, 155, 43, 52, 182, 26, 27, 233, 79, 114, 5,
                        137, 160, 129, 163, 240, 247, 6, 77, 180, 161, 155, 204, 206, 246, 150, 221, 17, 96, 201, 60,
                        17, 71, 103, 71, 100, 20, 134, 98, 42, 182, 85, 90, 53, 170, 107, 148, 22, 51, 122, 74, 39, 90,
                        130, 143, 196, 255, 144, 158, 252, 70, 81, 106, 249, 186, 210, 128, 127, 176, 90, 173, 193, 49,
                        12, 23, 83, 170, 206, 6, 247, 76, 160, 219, 238, 6, 30, 221, 9, 253, 158, 0, 0, 0, 160, 51, 48,
                        147, 115, 0, 0, 0,
                    ];

                    ##[#{tokio}::test]
                    async fn test_request_compression_isnt_applied_unless_modeled() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(true)
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.not_a_compressible_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set.
                        assert_eq!(None, request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_compression_can_be_disabled() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(true)
                            .build();

                        let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_min_size_body_over_minimum() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            .request_min_compression_size_bytes(128)
                            .build();

                            let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is set to "gzip"
                        assert_eq!(Some("gzip"), request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was compressed
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_min_size_body_under_minimum() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            .request_min_compression_size_bytes(256)
                            .build();

                            let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_compression_implicitly_enabled() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .request_min_compression_size_bytes(128)
                            .build();

                            let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is set to "gzip"
                        assert_eq!(Some("gzip"), request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was compressed
                        assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_compression_min_size_default() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            .build();

                            let client = $moduleName::Client::from_conf(config);
                        let _ = client.some_operation().body(Blob::new(UNCOMPRESSED_INPUT)).send().await;
                        let request = rx.expect_request();
                        // Check that the content-encoding header is not set to "gzip"
                        assert_ne!(Some("gzip"), request.headers().get("content-encoding"));

                        let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                        // Assert input body was not compressed
                        assert_eq!(UNCOMPRESSED_INPUT, compressed_body.as_slice())
                    }

                    ##[#{tokio}::test]
                    async fn test_request_compression_streaming_body() {
                        let (http_client, rx) = ::aws_smithy_runtime::client::http::test_util::capture_request(None);
                        let config = $moduleName::Config::builder()
                            .region(Region::from_static("doesntmatter"))
                            .with_test_defaults()
                            .http_client(http_client)
                            .disable_request_compression(false)
                            // Since our streaming body is sized, we have to set this.
                            .request_min_compression_size_bytes(128)
                            .build();

                            let client = $moduleName::Client::from_conf(config);
                            // ByteStreams created from a file are streaming and have a known size
                            let mut file = #{tempfile}::NamedTempFile::new().unwrap();
                            use std::io::Write;
                            file.write_all(UNCOMPRESSED_INPUT).unwrap();

                            let body = ByteStream::read_from()
                                .path(file.path())
                                .buffer_size(1024)
                                .length(#{Length}::Exact(UNCOMPRESSED_INPUT.len() as u64))
                                .build()
                                .await
                                .unwrap();
                            let _ = client
                                .some_streaming_operation()
                                .body(body)
                                .send()
                                .await;
                            let request = rx.expect_request();
                            // Check that the content-encoding header is set to "gzip"
                            assert_eq!(Some("gzip"), request.headers().get("content-encoding"));

                            let compressed_body = ByteStream::from(request.into_body()).collect().await.unwrap().to_vec();
                            // Assert input body is different from uncompressed input
                            assert_ne!(UNCOMPRESSED_INPUT, compressed_body.as_slice());
                            // Assert input body was compressed
                            assert_eq!(COMPRESSED_OUTPUT, compressed_body.as_slice());
                    }
                    """,
                    *preludeScope,
                    "ByteStream" to RuntimeType.smithyTypes(rc).resolve("byte_stream::ByteStream"),
                    "Blob" to RuntimeType.smithyTypes(rc).resolve("Blob"),
                    "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
                    "tokio" to CargoDependency.Tokio.toType(),
                    "capture_request" to RuntimeType.captureRequest(rc),
                    "pretty_assertions" to CargoDependency.PrettyAssertions.toType(),
                    "tempfile" to CargoDependency.TempFile.toType(),
                    "Length" to RuntimeType.smithyTypes(rc).resolve("byte_stream::Length"),
                )
            }
        }
    }
}
+0 −1
Original line number Diff line number Diff line
@@ -38,7 +38,6 @@ tokio = { version = "1.32.0", features = ["macros", "rt"] }
[target.'cfg(all(target_family = "wasm", target_os = "wasi"))'.dependencies]
wit-bindgen = { version = "0.16.0", features = ["macros", "realloc"] }


[lib]
crate-type = ["cdylib"]

+68 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading