Unverified Commit 37c1cc0b authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

add codegen test for sigv4a EPR (#3848)

exactly what it says on the tin. Let me know if we should have any other
tests verifying the function of sigv4a auth schemes.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent a64982a4
Loading
Loading
Loading
Loading
+112 −1
Original line number Diff line number Diff line
@@ -5,7 +5,12 @@
package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest

class SigV4AuthDecoratorTest {
    private val modelWithSigV4AuthScheme =
@@ -35,6 +40,7 @@ class SigV4AuthDecoratorTest {

        structure SomeInput {
            @httpPayload
            @required
            something: Bytestream
         }

@@ -56,10 +62,115 @@ class SigV4AuthDecoratorTest {
        @unsignedPayload
        @http(uri: "/", method: "POST")
        operation SomeOperation { input: SomeInput, output: SomeOutput }
        """.asSmithyModel()
        """.asSmithyModel(smithyVersion = "2.0")

    @Test
    fun unsignedPayloadSetsCorrectHeader() {
        awsSdkIntegrationTest(modelWithSigV4AuthScheme) { _, _ -> }
    }

    private val modelWithSigV4aAuthScheme =
        """
        namespace test

        use aws.auth#sigv4
        use aws.auth#sigv4a
        use aws.api#service
        use aws.protocols#restJson1
        use smithy.rules#endpointRuleSet
        use aws.auth#unsignedPayload
        use smithy.test#httpRequestTests

        @auth([sigv4a,sigv4])
        @sigv4(name: "dontcare")
        @sigv4a(name: "dontcare")
        @restJson1
        @endpointRuleSet({
            "version": "1.0",
            "rules": [
                {
                    "type": "endpoint",
                    "conditions": [],
                    "endpoint": {
                        "url": "https://example.com",
                        "properties": {
                            "authSchemes": [
                                {
                                    "name": "sigv4a",
                                    "signingRegionSet": ["*"],
                                    "signingName": "dontcare"
                                }
                            ]
                        }
                    }
                }
            ],
            "parameters": {
                "endpoint": { "required": true, "type": "string", "builtIn": "SDK::Endpoint" },
            }
        })
        @service(sdkId: "dontcare")
        service TestService { version: "2023-01-01", operations: [SomeOperation] }

        @streaming
        blob Bytestream

        structure SomeInput {
            @httpPayload
            @required
            something: Bytestream
         }

        structure SomeOutput { something: String }

        @http(uri: "/", method: "POST")
        operation SomeOperation { input: SomeInput, output: SomeOutput }
        """.asSmithyModel(smithyVersion = "2.0")

    @Test
    fun unsignedPayloadSetsCorrectHeaderForSigV4a() {
        awsSdkIntegrationTest(modelWithSigV4aAuthScheme) { clientCodegenContext, rustCrate ->
            val moduleUseName = clientCodegenContext.moduleUseName()
            val rc = clientCodegenContext.runtimeConfig

            rustCrate.integrationTest("sigv4a") {
                Attribute.featureGate("test-util").render(this)
                tokioTest("test_sigv4a_signing") {
                    rustTemplate(
                        """
                        let http_client = #{StaticReplayClient}::new(vec![#{ReplayEvent}::new(
                            #{Request}::builder()
                                .header("authorization", "AWS4-ECDSA-P256-SHA256 Credential=ANOTREAL/20090213/dontcare/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-region-set;x-amz-user-agent, Signature=3045022100b95d1c054ff04b676d12f0c893348606844d67ccf595981f0ca4968fae2eddfd022073e66edc0ad1da05b08392fccefa3ad69f8ec9393461033412fa05c55b749e9d")
                                .uri("https://example.com")
                                .body(#{SdkBody}::from("Hello, world!"))
                                .unwrap(),
                            #{Response}::builder().status(200).body(#{SdkBody}::empty()).unwrap(),
                        )]);
                        let config = $moduleUseName::Config::builder()
                            .http_client(http_client.clone())
                            .endpoint_url("https://example.com")
                            .behavior_version_latest()
                            .with_test_defaults()
                            .build();
                        let client = $moduleUseName::Client::from_conf(config);
                        let _ = client.some_operation().something(#{ByteStream}::from_static(b"Hello, world!")).send().await;

                        http_client.assert_requests_match(&["authorization"]);
                        let auth_header = http_client.actual_requests().next().unwrap().headers().get(http::header::AUTHORIZATION).unwrap();
                        assert!(auth_header.contains("AWS4-ECDSA-P256-SHA256"));
                        """,
                        "ByteStream" to RuntimeType.byteStream(rc),
                        "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc).resolve("Credentials"),
                        "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
                        "ReplayEvent" to RuntimeType.smithyRuntimeTestUtil(rc).resolve("ReplayEvent"),
                        "Request" to RuntimeType.HttpRequest,
                        "Response" to RuntimeType.HttpResponse,
                        "SdkBody" to RuntimeType.sdkBody(rc),
                        "StaticReplayClient" to RuntimeType.smithyRuntimeTestUtil(rc).resolve("StaticReplayClient"),
                        "tracing_subscriber" to RuntimeType.TracingSubscriber,
                    )
                }
            }
        }
    }
}
+6 −8
Original line number Diff line number Diff line
@@ -279,14 +279,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
        val Bytes = CargoDependency.Bytes.toType().resolve("Bytes")
        val Http = CargoDependency.Http.toType()
        val HttpBody = CargoDependency.HttpBody.toType()
        val HttpHeaderMap = Http.resolve("HeaderMap")
        val HttpRequest = Http.resolve("Request")
        val HttpRequestBuilder = Http.resolve("request::Builder")
        val HttpResponse = Http.resolve("Response")
        val HttpResponseBuilder = Http.resolve("response::Builder")
        val Hyper = CargoDependency.Hyper.toType()
        val LazyStatic = CargoDependency.LazyStatic.toType()
        val Md5 = CargoDependency.Md5.toType()
        val OnceCell = CargoDependency.OnceCell.toType()
        val PercentEncoding = CargoDependency.PercentEncoding.toType()
        val PrettyAssertions = CargoDependency.PrettyAssertions.toType()
@@ -294,12 +292,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
        val Serde = CargoDependency.Serde.toType()
        val SerdeDeserialize = Serde.resolve("Deserialize")
        val SerdeSerialize = Serde.resolve("Serialize")
        val RegexLite = CargoDependency.RegexLite.toType()
        val Tokio = CargoDependency.Tokio.toType()
        val TokioStream = CargoDependency.TokioStream.toType()
        val Tower = CargoDependency.Tower.toType()
        val Tracing = CargoDependency.Tracing.toType()
        val TracingTest = CargoDependency.TracingTest.toType()
        val TracingSubscriber = CargoDependency.TracingSubscriber.toType()

        // codegen types
        val ConstrainedTrait = RuntimeType("crate::constrained::Constrained", InlineDependency.constrained())
@@ -312,8 +310,6 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)

        fun smithyChecksums(runtimeConfig: RuntimeConfig) = CargoDependency.smithyChecksums(runtimeConfig).toType()

        fun smithyCompression(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCompression(runtimeConfig).toType()

        fun smithyEventStream(runtimeConfig: RuntimeConfig) = CargoDependency.smithyEventStream(runtimeConfig).toType()

        fun smithyHttp(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttp(runtimeConfig).toType()
@@ -444,7 +440,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
            forInlineDependency(InlineDependency.awsQueryCompatibleErrors(runtimeConfig))

        fun defaultAuthPlugin(runtimeConfig: RuntimeConfig) =
            RuntimeType.forInlineDependency(InlineDependency.defaultAuthPlugin(runtimeConfig))
            forInlineDependency(InlineDependency.defaultAuthPlugin(runtimeConfig))
                .resolve("DefaultAuthOptionsPlugin")

        fun labelFormat(
@@ -502,9 +498,11 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
            return smithyTypes(runtimeConfig).resolve("date_time::Format::$timestampFormat")
        }

        fun smithyRuntimeTestUtil(runtimeConfig: RuntimeConfig) =
            CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType().resolve("client::http::test_util")

        fun captureRequest(runtimeConfig: RuntimeConfig) =
            CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType()
                .resolve("client::http::test_util::capture_request")
            smithyRuntimeTestUtil(runtimeConfig).resolve("capture_request")

        fun forInlineDependency(inlineDependency: InlineDependency) =
            RuntimeType("crate::${inlineDependency.name}", inlineDependency)