Unverified Commit 336243c4 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

store identity resolvers in a map (#3363)



## Motivation and Context
- aws-sdk-rust#901
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->

## Description
Identity resolvers aren't a list—they are a map keyed on AuthSchemeId.
This updates the internal representation to match.

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
parent 30a801a8
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -28,3 +28,9 @@ message = "The `Metadata` storable was moved from aws_smithy_http into aws_smith
references = ["smithy-rs#3325"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Fix bug where overriding the credentials at the operation level failed if credentials were already set."
references = ["aws-sdk-rust#901", "smithy-rs#3363"]
meta = { "breaking" = false, "bug" = true, "tada" = false }
author = "rcoh"
+2 −33
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@
 */

use crate::box_error::BoxError;
use crate::client::auth::AuthSchemeId;
use crate::client::runtime_components::sealed::ValidateConfig;
use crate::client::runtime_components::{RuntimeComponents, RuntimeComponentsBuilder};
use crate::impl_shared_conversions;
@@ -108,6 +107,8 @@ impl ResolveCachedIdentity for SharedIdentityCache {
    }
}

impl ValidateConfig for SharedIdentityResolver {}

impl ValidateConfig for SharedIdentityCache {
    fn validate_base_client_config(
        &self,
@@ -197,38 +198,6 @@ impl ResolveIdentity for SharedIdentityResolver {

impl_shared_conversions!(convert SharedIdentityResolver from ResolveIdentity using SharedIdentityResolver::new);

/// An identity resolver paired with an auth scheme ID that it resolves for.
#[derive(Clone, Debug)]
pub(crate) struct ConfiguredIdentityResolver {
    auth_scheme: AuthSchemeId,
    identity_resolver: SharedIdentityResolver,
}

impl ConfiguredIdentityResolver {
    /// Creates a new [`ConfiguredIdentityResolver`] from the given auth scheme and identity resolver.
    pub(crate) fn new(
        auth_scheme: AuthSchemeId,
        identity_resolver: SharedIdentityResolver,
    ) -> Self {
        Self {
            auth_scheme,
            identity_resolver,
        }
    }

    /// Returns the auth scheme ID.
    pub(crate) fn scheme_id(&self) -> AuthSchemeId {
        self.auth_scheme
    }

    /// Returns the identity resolver.
    pub(crate) fn identity_resolver(&self) -> SharedIdentityResolver {
        self.identity_resolver.clone()
    }
}

impl ValidateConfig for ConfiguredIdentityResolver {}

/// An identity that can be used for authentication.
///
/// The [`Identity`] is a container for any arbitrary identity data that may be used
+56 −37
Original line number Diff line number Diff line
@@ -18,8 +18,7 @@ use crate::client::auth::{
use crate::client::endpoint::{ResolveEndpoint, SharedEndpointResolver};
use crate::client::http::{HttpClient, SharedHttpClient};
use crate::client::identity::{
    ConfiguredIdentityResolver, ResolveCachedIdentity, ResolveIdentity, SharedIdentityCache,
    SharedIdentityResolver,
    ResolveCachedIdentity, ResolveIdentity, SharedIdentityCache, SharedIdentityResolver,
};
use crate::client::interceptors::{Intercept, SharedInterceptor};
use crate::client::retries::classifiers::{ClassifyRetry, SharedRetryClassifier};
@@ -29,6 +28,7 @@ use crate::shared::IntoShared;
use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
use aws_smithy_async::time::{SharedTimeSource, TimeSource};
use aws_smithy_types::config_bag::ConfigBag;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

@@ -177,6 +177,13 @@ macro_rules! merge {
            $self.$name.extend($other.$name.iter().cloned());
        }
    };
    (OptionalAuthSchemeMap $other:ident . $name:ident => $self:ident ) => {
        if let Some(m) = &$other.$name {
            let mut us = $self.$name.unwrap_or_default();
            us.extend(m.iter().map(|(k, v)| (k.clone(), v.clone())));
            $self.$name = Some(us);
        }
    };
}
/// Internal to `declare_runtime_components!`.
///
@@ -197,6 +204,18 @@ macro_rules! builder_field_value {
    (Vec $self:ident . $name:ident) => {
        $self.$name
    };
    (OptionalAuthSchemeMap $self:ident . $name:ident atLeastOneRequired) => {{
        match $self.$name {
            Some(map) => map,
            None => {
                return Err(BuildError(concat!(
                    "at least one `",
                    stringify!($name),
                    "` runtime component is required"
                )));
            }
        }
    }};
    (Vec $self:ident . $name:ident atLeastOneRequired) => {{
        if $self.$name.is_empty() {
            return Err(BuildError(concat!(
@@ -225,6 +244,7 @@ macro_rules! runtime_component_field_type {
    (Vec $inner_type:ident atLeastOneRequired) => {
        Vec<Tracked<$inner_type>>
    };
    (OptionalAuthSchemeMap $inner_type: ident atLeastOneRequired) => { AuthSchemeMap<Tracked<$inner_type>> };
}
/// Internal to `declare_runtime_components!`.
///
@@ -238,8 +258,14 @@ macro_rules! empty_builder_value {
    (Vec) => {
        Vec::new()
    };
    (OptionalAuthSchemeMap) => {
        None
    };
}

type OptionalAuthSchemeMap<V> = Option<AuthSchemeMap<V>>;
type AuthSchemeMap<V> = HashMap<AuthSchemeId, V>;

/// Macro to define the structs for both `RuntimeComponents` and `RuntimeComponentsBuilder`.
///
/// This is a macro in order to keep the fields consistent between the two, and to automatically
@@ -334,7 +360,7 @@ declare_runtime_components! {
        identity_cache: Option<SharedIdentityCache>,

        #[atLeastOneRequired]
        identity_resolvers: Vec<ConfiguredIdentityResolver>,
        identity_resolvers: OptionalAuthSchemeMap<SharedIdentityResolver>,

        interceptors: Vec<SharedInterceptor>,

@@ -429,7 +455,12 @@ impl RuntimeComponents {
                }
            };
            (Vec: $field:expr) => {
                for entry in &$field {
                for entry in $field {
                    ValidateConfig::validate_final_config(&entry.value, self, cfg)?;
                }
            };
            (Map: $field:expr) => {
                for entry in $field.values() {
                    ValidateConfig::validate_final_config(&entry.value, self, cfg)?;
                }
            };
@@ -441,10 +472,10 @@ impl RuntimeComponents {
        }
        validate!(Option: self.http_client);
        validate!(Required: self.endpoint_resolver);
        validate!(Vec: self.auth_schemes);
        validate!(Vec: &self.auth_schemes);
        validate!(Required: self.identity_cache);
        validate!(Vec: self.identity_resolvers);
        validate!(Vec: self.interceptors);
        validate!(Map: self.identity_resolvers);
        validate!(Vec: &self.interceptors);
        validate!(Required: self.retry_strategy);
        Ok(())
    }
@@ -570,11 +601,7 @@ impl RuntimeComponentsBuilder {
        scheme_id: AuthSchemeId,
        identity_resolver: impl ResolveIdentity + 'static,
    ) -> &mut Self {
        self.identity_resolvers.push(Tracked::new(
            self.builder_name,
            ConfiguredIdentityResolver::new(scheme_id, identity_resolver.into_shared()),
        ));
        self
        self.set_identity_resolver(scheme_id, identity_resolver)
    }

    /// Sets the identity resolver for a given `scheme_id`.
@@ -586,21 +613,12 @@ impl RuntimeComponentsBuilder {
        scheme_id: AuthSchemeId,
        identity_resolver: impl ResolveIdentity + 'static,
    ) -> &mut Self {
        let tracked = Tracked::new(
            self.builder_name,
            ConfiguredIdentityResolver::new(scheme_id, identity_resolver.into_shared()),
        let mut resolvers = self.identity_resolvers.take().unwrap_or_default();
        resolvers.insert(
            scheme_id,
            Tracked::new(self.builder_name, identity_resolver.into_shared()),
        );

        if let Some(s) = self
            .identity_resolvers
            .iter_mut()
            .find(|s| s.value.scheme_id() == scheme_id)
        {
            *s = tracked;
        } else {
            self.identity_resolvers.push(tracked);
        }

        self.identity_resolvers = Some(resolvers);
        self
    }

@@ -807,7 +825,7 @@ impl RuntimeComponentsBuilder {
    pub fn validate_base_client_config(&self, cfg: &ConfigBag) -> Result<(), BoxError> {
        macro_rules! validate {
            ($field:expr) => {
                for entry in &$field {
                for entry in $field {
                    ValidateConfig::validate_base_client_config(&entry.value, self, cfg)?;
                }
            };
@@ -817,13 +835,15 @@ impl RuntimeComponentsBuilder {
        for validator in self.config_validators() {
            validator.validate_base_client_config(self, cfg)?;
        }
        validate!(self.http_client);
        validate!(self.endpoint_resolver);
        validate!(self.auth_schemes);
        validate!(self.identity_cache);
        validate!(self.identity_resolvers);
        validate!(self.interceptors);
        validate!(self.retry_strategy);
        validate!(&self.http_client);
        validate!(&self.endpoint_resolver);
        validate!(&self.auth_schemes);
        validate!(&self.identity_cache);
        if let Some(resolvers) = &self.identity_resolvers {
            validate!(resolvers.values())
        }
        validate!(&self.interceptors);
        validate!(&self.retry_strategy);
        Ok(())
    }

@@ -1010,9 +1030,8 @@ pub trait GetIdentityResolver: Send + Sync {
impl GetIdentityResolver for RuntimeComponents {
    fn identity_resolver(&self, scheme_id: AuthSchemeId) -> Option<SharedIdentityResolver> {
        self.identity_resolvers
            .iter()
            .find(|s| s.value.scheme_id() == scheme_id)
            .map(|s| s.value.identity_resolver())
            .get(&scheme_id)
            .map(|s| s.value.clone())
    }
}