Unverified Commit 78e6636f authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Align auth implementation with SRA (#4198)

## Motivation and Context
This PR makes the auth implementation aligned with our design
specifications. The code changes in the PR consolidate the previous
sub-PRs that have been reviewed already:
- https://github.com/smithy-lang/smithy-rs/pull/4161
- https://github.com/smithy-lang/smithy-rs/pull/4175
- https://github.com/smithy-lang/smithy-rs/pull/4182
- https://github.com/smithy-lang/smithy-rs/pull/4183



For auth-related types users directly interact with in service config
such as `AuthScheme`, `AuthSchemeOption`, we can re-export them in a
client SDK upon a customer request.

## Testing
- All tests added in the sub-PRs above

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.
- [x] For changes to the AWS SDK, generated SDK code, or SDK runtime
crates, I have created a changelog entry Markdown file in the
`.changelog` directory, specifying "aws-sdk-rust" in the `applies_to`
key.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarAWS SDK Rust Bot <97246200+aws-sdk-rust-ci@users.noreply.github.com>
parent a98b5b60
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
---
applies_to:
- aws-sdk-rust
- client
authors:
- ysaito1001
references:
- smithy-rs#4076
- smithy-rs#4198
breaking: false
new_feature: false
bug_fix: false
---
Allows customers to configure the auth schemes and auth scheme resolver. For more information see the GitHub [discussion](https://github.com/smithy-lang/smithy-rs/discussions/4197).
+2 −2
Original line number Diff line number Diff line
@@ -143,7 +143,7 @@ dependencies = [

[[package]]
name = "aws-runtime"
version = "1.5.8"
version = "1.5.9"
dependencies = [
 "arbitrary",
 "aws-credential-types",
@@ -319,7 +319,7 @@ dependencies = [

[[package]]
name = "aws-smithy-runtime"
version = "1.8.3"
version = "1.8.4"
dependencies = [
 "aws-smithy-async",
 "aws-smithy-http",
+193 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::borrow::Cow;

use aws_smithy_runtime_api::{
    box_error::BoxError,
    client::{
        auth::{AuthSchemeId, AuthSchemeOption},
        endpoint::{EndpointResolverParams, ResolveEndpoint},
        runtime_components::RuntimeComponents,
    },
};
use aws_smithy_types::config_bag::ConfigBag;

pub(crate) async fn resolve_endpoint_based_auth_scheme_options<'a>(
    modeled_auth_scheme_options: &'a [AuthSchemeOption],
    cfg: &'a ConfigBag,
    runtime_components: &'a RuntimeComponents,
) -> Result<Vec<AuthSchemeOption>, BoxError> {
    let endpoint_params = cfg
        .load::<EndpointResolverParams>()
        .expect("endpoint resolver params must be set");

    tracing::debug!(endpoint_params = ?endpoint_params, "resolving endpoint for auth scheme selection");

    let endpoint = runtime_components
        .endpoint_resolver()
        .resolve_endpoint(endpoint_params)
        .await?;

    let mut endpoint_auth_scheme_ids = Vec::new();

    // Note that we're not constructing the `properties` for `endpoint_auth_schemes` here—only collecting
    // auth scheme IDs but not properties. This is because, at this stage, we're only determining which auth schemes will be candidates.
    // Any `authSchemes` list properties that influence the signing context will be extracted later
    // in `AuthSchemeEndpointConfig`, and passed by the orchestrator to the signer's `sign_http_request` method.
    if let Some(aws_smithy_types::Document::Array(endpoint_auth_schemes)) =
        endpoint.properties().get("authSchemes")
    {
        for endpoint_auth_scheme in endpoint_auth_schemes {
            let scheme_id_str = endpoint_auth_scheme
                .as_object()
                .and_then(|object| object.get("name"))
                .and_then(aws_smithy_types::Document::as_string);
            if let Some(scheme_id_str) = scheme_id_str {
                endpoint_auth_scheme_ids
                    .push(AuthSchemeId::from(Cow::Owned(scheme_id_str.to_owned())));
            }
        }
    }

    Ok(merge_auth_scheme_options(
        modeled_auth_scheme_options,
        endpoint_auth_scheme_ids,
    ))
}

// Returns a list of merged auth scheme options from `modeled_auth_scheme_options` and `endpoint_auth_scheme_ids`,
// copying properties from the modeled auth scheme options into the endpoint auth scheme options as they are built.
//
// Note: We only extract properties from the modeled auth schemes. Pulling properties from the endpoint auth schemes
// would result in duplication; they would be added here and again in the `extract_operation_config` function during signing.
fn merge_auth_scheme_options(
    modeled_auth_scheme_options: &[AuthSchemeOption],
    endpoint_auth_scheme_ids: Vec<AuthSchemeId>,
) -> Vec<AuthSchemeOption> {
    let (common_auth_scheme_options, model_only_auth_scheme_options): (Vec<_>, Vec<_>) =
        modeled_auth_scheme_options
            .iter()
            .partition(|auth_scheme_option| {
                endpoint_auth_scheme_ids.contains(auth_scheme_option.scheme_id())
            });

    let mut endpoint_auth_scheme_options = endpoint_auth_scheme_ids
        .into_iter()
        .map(|id| {
            let modelded = common_auth_scheme_options
                .iter()
                .find(|opt| opt.scheme_id() == &id)
                .cloned();
            let mut builder = AuthSchemeOption::builder().scheme_id(id);
            builder.set_properties(modelded.and_then(|m| m.properties()));
            builder.build().unwrap()
        })
        .collect::<Vec<_>>();

    endpoint_auth_scheme_options.extend(model_only_auth_scheme_options.into_iter().cloned());

    endpoint_auth_scheme_options
}

#[cfg(test)]
mod tests {
    use aws_runtime::auth::PayloadSigningOverride;
    use aws_smithy_types::config_bag::Layer;

    use super::*;

    fn into_auth_scheme_ids<const N: usize>(strs: [&'static str; N]) -> Vec<AuthSchemeId> {
        strs.into_iter().map(AuthSchemeId::from).collect::<Vec<_>>()
    }

    fn into_auth_scheme_options<const N: usize>(strs: [&'static str; N]) -> Vec<AuthSchemeOption> {
        strs.into_iter()
            .map(|s| AuthSchemeOption::from(AuthSchemeId::from(s)))
            .collect::<Vec<_>>()
    }

    #[test]
    fn merge_auth_scheme_options_basic() {
        let modeled_auth_scheme_options =
            into_auth_scheme_options(["schemeA", "schemeX", "schemeB", "schemeY"]);
        let endpoint_auth_scheme_ids = into_auth_scheme_ids(["schemeY", "schemeX"]);
        let expected = ["schemeY", "schemeX", "schemeA", "schemeB"];
        let actual =
            merge_auth_scheme_options(&modeled_auth_scheme_options, endpoint_auth_scheme_ids);
        assert_eq!(
            expected.to_vec(),
            actual
                .iter()
                .map(|opt| opt.scheme_id().inner())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn merge_auth_scheme_options_preserving_modeled_auth_properties() {
        let mut modeled_auth_scheme_options =
            into_auth_scheme_options(["schemeA", "schemeX", "schemeB"]);
        modeled_auth_scheme_options.push(
            AuthSchemeOption::builder()
                .scheme_id(AuthSchemeId::new("schemeY"))
                .properties({
                    let mut layer = Layer::new("TestAuthSchemeProperties");
                    layer.store_put(PayloadSigningOverride::unsigned_payload());
                    layer.freeze()
                })
                .build()
                .unwrap(),
        );
        let endpoint_auth_scheme_ids = into_auth_scheme_ids(["schemeY", "schemeX"]);
        let expected = ["schemeY", "schemeX", "schemeA", "schemeB"];
        let actual =
            merge_auth_scheme_options(&modeled_auth_scheme_options, endpoint_auth_scheme_ids);
        assert_eq!(
            expected.to_vec(),
            actual
                .iter()
                .map(|opt| opt.scheme_id().inner())
                .collect::<Vec<_>>()
        );
        let prop = actual.first().unwrap().properties().unwrap();
        assert!(matches!(
            prop.load::<PayloadSigningOverride>().unwrap(),
            PayloadSigningOverride::UnsignedPayload
        ));
    }

    #[test]
    fn merge_auth_scheme_options_with_empty_endpoint_auth_scheme_options() {
        let expected = ["schemeA", "schemeX", "schemeB", "schemeY"];
        let modeled_auth_scheme_options = into_auth_scheme_options(expected);
        let endpoint_auth_scheme_ids = Vec::new();
        let actual =
            merge_auth_scheme_options(&modeled_auth_scheme_options, endpoint_auth_scheme_ids);
        assert_eq!(
            expected.to_vec(),
            actual
                .iter()
                .map(|opt| opt.scheme_id().inner())
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn merge_auth_scheme_options_should_also_include_those_only_in_endpoint_auth_scheme_options() {
        let modeled_auth_scheme_ids =
            into_auth_scheme_options(["schemeA", "schemeX", "schemeB", "schemeY"]);
        let endpoint_auth_scheme_ids = into_auth_scheme_ids(["schemeY", "schemeZ"]);
        let expected = ["schemeY", "schemeZ", "schemeA", "schemeX", "schemeB"];
        let actual = merge_auth_scheme_options(&modeled_auth_scheme_ids, endpoint_auth_scheme_ids);
        assert_eq!(
            expected.to_vec(),
            actual
                .iter()
                .map(|opt| opt.scheme_id().inner())
                .collect::<Vec<_>>()
        );
    }
}
+0 −170
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::borrow::Cow;

use aws_smithy_runtime_api::client::{
    auth::{AuthSchemeId, AuthSchemeOption},
    endpoint::ResolveEndpoint,
    runtime_components::RuntimeComponentsBuilder,
    runtime_plugin::{Order, RuntimePlugin},
};

// A runtime plugin that registers `EndpointBasedAuthSchemeOptionResolver` with `RuntimeComponents`.
#[derive(Debug)]
pub(crate) struct EndpointBasedAuthOptionsPlugin {
    runtime_components: RuntimeComponentsBuilder,
}

impl EndpointBasedAuthOptionsPlugin {
    pub(crate) fn new(auth_schemes: Vec<AuthSchemeId>) -> Self {
        let runtime_components = RuntimeComponentsBuilder::new("endpoint_based_auth_options")
            .with_auth_scheme_option_resolver(Some(EndpointBasedAuthSchemeOptionResolver::new(
                auth_schemes,
            )));
        Self { runtime_components }
    }
}

impl RuntimePlugin for EndpointBasedAuthOptionsPlugin {
    fn order(&self) -> Order {
        Order::Defaults
    }

    fn runtime_components(
        &self,
        _current_components: &RuntimeComponentsBuilder,
    ) -> Cow<'_, RuntimeComponentsBuilder> {
        Cow::Borrowed(&self.runtime_components)
    }
}

// An `AuthSchemeOptionResolver` that prioritizes auth scheme options from an internally resolved endpoint.
//
// The code generator provides `modeled_auth_scheme_ids`, but during auth scheme option resolution, their priority is
// overridden by the auth scheme options specified in the endpoint. `resolve_auth_scheme_options_v2` places the endpoint's
// options at the beginning of the resulting list. Furthermore, if the endpoint includes an unmodeled auth scheme option, this resolver
// will dynamically generate the option and add it to the resulting list.
#[derive(Debug)]
pub(crate) struct EndpointBasedAuthSchemeOptionResolver {
    modeled_auth_scheme_ids: Vec<AuthSchemeId>,
}

impl EndpointBasedAuthSchemeOptionResolver {
    /// Creates a new instance of `EndpointBasedAuthSchemeOptionResolver`.
    pub(crate) fn new(modeled_auth_scheme_ids: Vec<AuthSchemeId>) -> Self {
        Self {
            modeled_auth_scheme_ids,
        }
    }
}

impl aws_smithy_runtime_api::client::auth::ResolveAuthSchemeOptions
    for EndpointBasedAuthSchemeOptionResolver
{
    fn resolve_auth_scheme_options_v2<'a>(
        &'a self,
        _params: &'a aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams,
        cfg: &'a aws_smithy_types::config_bag::ConfigBag,
        runtime_components: &'a aws_smithy_runtime_api::client::runtime_components::RuntimeComponents,
    ) -> aws_smithy_runtime_api::client::auth::AuthSchemeOptionsFuture<'a> {
        aws_smithy_runtime_api::client::auth::AuthSchemeOptionsFuture::new(async move {
            let endpoint_params = cfg
                .load::<aws_smithy_runtime_api::client::endpoint::EndpointResolverParams>()
                .expect("endpoint resolver params must be set");

            tracing::debug!(endpoint_params = ?endpoint_params, "resolving endpoint for auth scheme selection");

            let endpoint = runtime_components
                .endpoint_resolver()
                .resolve_endpoint(endpoint_params)
                .await?;

            let mut endpoint_auth_scheme_ids = Vec::new();

            if let Some(aws_smithy_types::Document::Array(endpoint_auth_schemes)) =
                endpoint.properties().get("authSchemes")
            {
                for endpoint_auth_scheme in endpoint_auth_schemes {
                    let scheme_id_str = endpoint_auth_scheme
                        .as_object()
                        .and_then(|object| object.get("name"))
                        .and_then(aws_smithy_types::Document::as_string);
                    if let Some(scheme_id_str) = scheme_id_str {
                        endpoint_auth_scheme_ids
                            .push(AuthSchemeId::from(Cow::Owned(scheme_id_str.to_owned())));
                    }
                }
            }

            let result =
                merge_auth_scheme_ids(&self.modeled_auth_scheme_ids, endpoint_auth_scheme_ids);

            Ok(result
                .into_iter()
                .map(|auth_scheme_id| {
                    AuthSchemeOption::builder()
                        .scheme_id(auth_scheme_id)
                        .build()
                        .expect("required fields set")
                })
                .collect::<Vec<_>>())
        })
    }
}

// Merge a list of `AuthSchemeId`s both in `modeled_auth_scheme_ids` and in `endpoint_auth_scheme_ids`,
// but with those in `endpoint_auth_scheme_ids` placed at the front of the resulting list.
fn merge_auth_scheme_ids(
    modeled_auth_scheme_ids: &[AuthSchemeId],
    mut endpoint_auth_scheme_ids: Vec<AuthSchemeId>,
) -> Vec<AuthSchemeId> {
    let (_, model_only_auth_scheme_ids): (Vec<_>, Vec<_>) = modeled_auth_scheme_ids
        .iter()
        .partition(|auth_scheme_id| endpoint_auth_scheme_ids.contains(auth_scheme_id));

    endpoint_auth_scheme_ids.extend(model_only_auth_scheme_ids.into_iter().cloned());

    endpoint_auth_scheme_ids
}

#[cfg(test)]
mod tests {
    use super::*;

    fn into_auth_scheme_ids<const N: usize>(strs: [&'static str; N]) -> Vec<AuthSchemeId> {
        strs.into_iter().map(AuthSchemeId::from).collect::<Vec<_>>()
    }

    #[test]
    fn merge_auth_scheme_ids_basic() {
        let modeled_auth_scheme_ids =
            into_auth_scheme_ids(["schemeA", "schemeX", "schemeB", "schemeY"]);
        let endpoint_auth_scheme_ids = into_auth_scheme_ids(["schemeY", "schemeX"]);
        let expected = into_auth_scheme_ids(["schemeY", "schemeX", "schemeA", "schemeB"]);
        let actual = merge_auth_scheme_ids(&modeled_auth_scheme_ids, endpoint_auth_scheme_ids);
        assert_eq!(expected, actual);
    }

    #[test]
    fn merge_auth_scheme_ids_with_empty_endpoint_auth_scheme_ids() {
        let modeled_auth_scheme_ids =
            into_auth_scheme_ids(["schemeA", "schemeX", "schemeB", "schemeY"]);
        let endpoint_auth_scheme_ids = Vec::new();
        let actual = merge_auth_scheme_ids(&modeled_auth_scheme_ids, endpoint_auth_scheme_ids);
        assert_eq!(modeled_auth_scheme_ids, actual);
    }

    #[test]
    fn merge_auth_scheme_ids_should_also_include_those_only_in_endpoint_auth_scheme_ids() {
        let modeled_auth_scheme_ids =
            into_auth_scheme_ids(["schemeA", "schemeX", "schemeB", "schemeY"]);
        let endpoint_auth_scheme_ids = into_auth_scheme_ids(["schemeY", "schemeZ"]);
        let expected =
            into_auth_scheme_ids(["schemeY", "schemeZ", "schemeA", "schemeX", "schemeB"]);
        let actual = merge_auth_scheme_ids(&modeled_auth_scheme_ids, endpoint_auth_scheme_ids);
        assert_eq!(expected, actual);
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -22,9 +22,9 @@
    unreachable_pub
)]

/// An endpoint-based `AuthSchemeOptionResolver` type, with `RuntimePlugin` trait impl for registering it in `RuntimeComponents`.
/// Supporting code to determine auth scheme options based on the `authSchemes` endpoint list property.
#[allow(dead_code)]
pub mod endpoint_auth_plugin;
pub mod endpoint_auth;

/// Interceptors for API Gateway
pub mod apigateway_interceptors;
Loading