Commit 3af97de9 authored by ysaito1001's avatar ysaito1001 Committed by Russell Cohen
Browse files

Ensure identity resolver is set when a credentials provider is given only at...

Ensure identity resolver is set when a credentials provider is given only at operation level (#3156)

Fixes https://github.com/awslabs/aws-sdk-rust/issues/901

This PR is a rework of https://github.com/awslabs/smithy-rs/pull/3021
whose fix was inadvertently discarded during
https://github.com/awslabs/smithy-rs/pull/3077. The way we fix the issue
is slightly different. In this PR, we add an identity resolver to
runtime components within `set_credentials_provider`, instead of using
`ServiceConfig.OperationConfigOverride`.

Added a Kotlin integration test to `CredentialProviderConfigTest.kt`
based on the customer reported issue.

<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent adb4b112
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -16,3 +16,9 @@ message = "Fix aws-sdk-rust#930 (PutSnapshotBlock)"
references = ["smithy-rs#3126", "aws-sdk-rust#930"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "rcoh"

[[aws-sdk-rust]]
message = "Fix exclusively setting the credentials provider at operation config-override time. It's now possible to set the credentials when an operation is sent (via `.config_override()`), rather than at client-creation time."
references = ["smithy-rs#3156", "aws-sdk-rust#901"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "ysaito1001"
+29 −53
Original line number Diff line number Diff line
@@ -10,17 +10,13 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.TestUtilFeature
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.supportedAuthSchemes
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.featureGateBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
@@ -30,12 +26,6 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
    override val name: String = "CredentialsProvider"
    override val order: Byte = 0

    override fun serviceRuntimePluginCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ServiceRuntimePluginCustomization>,
    ): List<ServiceRuntimePluginCustomization> =
        baseCustomizations + listOf(CredentialsIdentityResolverRegistration(codegenContext))

    override fun configCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
@@ -65,7 +55,7 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
/**
 * Add a `.credentials_provider` field and builder to the `Config` for a given service
 */
class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomization() {
class CredentialProviderConfig(private val codegenContext: ClientCodegenContext) : ConfigCustomization() {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val codegenScope = arrayOf(
        *preludeScope,
@@ -74,6 +64,10 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
            .resolve("provider::ProvideCredentials"),
        "SharedCredentialsProvider" to AwsRuntimeType.awsCredentialTypes(runtimeConfig)
            .resolve("provider::SharedCredentialsProvider"),
        "SIGV4A_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
            .resolve("auth::sigv4a::SCHEME_ID"),
        "SIGV4_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
            .resolve("auth::sigv4::SCHEME_ID"),
        "TestCredentials" to AwsRuntimeType.awsCredentialTypesTestUtil(runtimeConfig).resolve("Credentials"),
    )

@@ -103,60 +97,42 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
                    *codegenScope,
                )

                rustTemplate(
                rustBlockTemplate(
                    """
                    /// Sets the credentials provider for this service
                    pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self {
                        self.config.store_or_unset(credentials_provider);
                        self
                    }
                    pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self
                    """,
                    *codegenScope,
                ) {
                    rustBlockTemplate(
                        """
                        if let Some(credentials_provider) = credentials_provider
                        """,
                        *codegenScope,
                    ) {
                        if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
                            featureGateBlock("sigv4a") {
                                rustTemplate(
                                    "self.runtime_components.push_identity_resolver(#{SIGV4A_SCHEME_ID}, credentials_provider.clone());",
                                    *codegenScope,
                                )
                            }

            is ServiceConfig.DefaultForTests -> rustTemplate(
                "${section.configBuilderRef}.set_credentials_provider(Some(#{SharedCredentialsProvider}::new(#{TestCredentials}::for_tests())));",
                        }
                        rustTemplate(
                            "self.runtime_components.push_identity_resolver(#{SIGV4_SCHEME_ID}, credentials_provider);",
                            *codegenScope,
                        )

            else -> emptySection
                    }
                    rust("self")
                }
            }

class CredentialsIdentityResolverRegistration(
    private val codegenContext: ClientCodegenContext,
) : ServiceRuntimePluginCustomization() {
    private val runtimeConfig = codegenContext.runtimeConfig

    override fun section(section: ServiceRuntimePluginSection): Writable = writable {
        when (section) {
            is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
                rustBlockTemplate("if let Some(creds_provider) = ${section.serviceConfigName}.credentials_provider()") {
                    val codegenScope = arrayOf(
                        "SharedIdentityResolver" to RuntimeType.smithyRuntimeApi(runtimeConfig)
                            .resolve("client::identity::SharedIdentityResolver"),
                        "SIGV4A_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
                            .resolve("auth::sigv4a::SCHEME_ID"),
                        "SIGV4_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
                            .resolve("auth::sigv4::SCHEME_ID"),
            is ServiceConfig.DefaultForTests -> rustTemplate(
                "${section.configBuilderRef}.set_credentials_provider(Some(#{SharedCredentialsProvider}::new(#{TestCredentials}::for_tests())));",
                *codegenScope,
            )

                    if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
                        featureGateBlock("sigv4a") {
                            section.registerIdentityResolver(this) {
                                rustTemplate("#{SIGV4A_SCHEME_ID}, creds_provider.clone()", *codegenScope)
                            }
                        }
                    }
                    section.registerIdentityResolver(this) {
                        rustTemplate("#{SIGV4_SCHEME_ID}, creds_provider,", *codegenScope)
                    }
                }
            }

            else -> {}
            else -> emptySection
        }
    }
}
+54 −0
Original line number Diff line number Diff line
@@ -7,6 +7,10 @@ package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest

internal class CredentialProviderConfigTest {
    @Test
@@ -14,4 +18,54 @@ internal class CredentialProviderConfigTest {
        val codegenContext = awsTestCodegenContext()
        validateConfigCustomizations(codegenContext, CredentialProviderConfig(codegenContext))
    }

    @Test
    fun `configuring credentials provider at operation level should work`() {
        awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { ctx, rustCrate ->
            val rc = ctx.runtimeConfig
            val codegenScope = arrayOf(
                *RuntimeType.preludeScope,
                "capture_request" to RuntimeType.captureRequest(rc),
                "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc)
                    .resolve("Credentials"),
                "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
            )
            rustCrate.integrationTest("credentials_provider") {
                // per https://github.com/awslabs/aws-sdk-rust/issues/901
                tokioTest("configuring_credentials_provider_at_operation_level_should_work") {
                    val moduleName = ctx.moduleUseName()
                    rustTemplate(
                        """
                        let (http_client, _rx) = #{capture_request}(None);
                        let client_config = $moduleName::Config::builder()
                            .http_client(http_client)
                            .build();

                        let client = $moduleName::Client::from_conf(client_config);

                        let credentials = #{Credentials}::new(
                            "test",
                            "test",
                            #{None},
                            #{None},
                            "test",
                        );
                        let operation_config_override = $moduleName::Config::builder()
                            .credentials_provider(credentials.clone())
                            .region(#{Region}::new("us-west-2"));

                        let _ = client
                            .some_operation()
                            .customize()
                            .config_override(operation_config_override)
                            .send()
                            .await
                            .expect("success");
                        """,
                        *codegenScope,
                    )
                }
            }
        }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ async fn generate_random() {
            .header("content-type", "application/x-amz-json-1.1")
            .header("x-amz-target", "TrentService.GenerateRandom")
            .header("content-length", "20")
            .header("authorization", "AWS4-HMAC-SHA256 Credential=ANOTREAL/20090213/us-east-1/kms/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-target;x-amz-user-agent, Signature=53dcf70f6f852cb576185dcabef5aaa3d068704cf1b7ea7dc644efeaa46674d7")
            .header("authorization", "AWS4-HMAC-SHA256 Credential=ANOTREAL/20090213/us-east-1/kms/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token;x-amz-target;x-amz-user-agent, Signature=703f72fe50c310e3ee1a7a106df947b980cb91bc8bad7a4a603b057096603aed")
            .header("x-amz-date", "20090213T233130Z")
            .header("user-agent", "aws-sdk-rust/0.123.test os/windows/XPSP3 lang/rust/1.50.0")
            .header("x-amz-user-agent", "aws-sdk-rust/0.123.test api/test-service/0.123 os/windows/XPSP3 lang/rust/1.50.0")
+1 −1
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ async fn signv4_use_correct_service_name() {
            .header("content-type", "application/x-amz-json-1.0")
            .header("x-amz-target", "QLDBSession.SendCommand")
            .header("content-length", "49")
            .header("authorization", "AWS4-HMAC-SHA256 Credential=ANOTREAL/20090213/us-east-1/qldb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-target;x-amz-user-agent, Signature=9a07c60550504d015fb9a2b0f1b175a4d906651f9dd4ee44bebb32a802d03815")
            .header("authorization", "AWS4-HMAC-SHA256 Credential=ANOTREAL/20090213/us-east-1/qldb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token;x-amz-target;x-amz-user-agent, Signature=e8d50282fa369adf05f33a5b32e3ce2a7582edc902312c59de311001a97426d9")
            // qldbsession uses the signing name 'qldb' in signature _________________________^^^^
            .header("x-amz-date", "20090213T233130Z")
            .header("user-agent", "aws-sdk-rust/0.123.test os/windows/XPSP3 lang/rust/1.50.0")
Loading