Commit 29732031 authored by ysaito1001's avatar ysaito1001
Browse files
parent 27102829
Loading
Loading
Loading
Loading
+28 −32
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.ConditionalDe
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
@@ -238,39 +240,33 @@ class AccountIdEndpointModeBuiltInParamDecorator : ConditionalDecorator(
                                )
                            }
                        }
                    },
                )

                        override fun trackSdkFeatures(
            override fun serviceRuntimePluginCustomizations(
                codegenContext: ClientCodegenContext,
                            configBag: String,
                        ) = writable {
                baseCustomizations: List<ServiceRuntimePluginCustomization>,
            ): List<ServiceRuntimePluginCustomization> =
                baseCustomizations + listOf(AccountIdEndpointFeatureTrackerInterceptor(codegenContext))
        },
)

private class AccountIdEndpointFeatureTrackerInterceptor(codegenContext: ClientCodegenContext) :
    ServiceRuntimePluginCustomization() {
    override fun section(section: ServiceRuntimePluginSection) =
        writable {
            if (section is ServiceRuntimePluginSection.RegisterRuntimeComponents) {
                section.registerInterceptor(this) {
                    rustTemplate(
                                """
                                match cfg
                                    .load::<#{AccountIdEndpointMode}>()
                                    .cloned()
                                    .unwrap_or_default()
                                {
                                    #{AccountIdEndpointMode}::Preferred => {
                                        $configBag.interceptor_state().store_append(#{AwsSdkFeature}::AccountIdModePreferred);
                                    }
                                    #{AccountIdEndpointMode}::Required => {
                                        $configBag.interceptor_state().store_append(#{AwsSdkFeature}::AccountIdModeRequired);
                                    }
                                    #{AccountIdEndpointMode}::Disabled => {
                                        $configBag.interceptor_state().store_append(#{AwsSdkFeature}::AccountIdModeDisabled);
                        "#{Interceptor}::default()",
                        "Interceptor" to
                            RuntimeType.forInlineDependency(
                                InlineAwsDependency.forRustFile(
                                    "account_id_endpoint",
                                ),
                            ).resolve("AccountIdEndpointFeatureTrackerInterceptor"),
                    )
                }
                                    otherwise => {
                                        #{tracing}::warn!(
                                            "Attempted to track an SDK feature for `{otherwise:?}`, which is not recognized in the current version of the SDK. \
                                            Consider upgrading to the latest version to ensure that it is properly tracked."
                                        );
            }
        }
                                """,
                                *codegenScope,
                            )
}
                    },
                )
        },
)
+55 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_runtime::sdk_feature::AwsSdkFeature;
use aws_smithy_runtime_api::{
    box_error::BoxError,
    client::interceptors::{context::BeforeSerializationInterceptorContextRef, Intercept},
};
use aws_smithy_types::config_bag::ConfigBag;
use aws_types::endpoint_config::AccountIdEndpointMode;

// Interceptor that tracks AWS SDK features for the account based endpoints.
#[derive(Debug, Default)]
pub(crate) struct AccountIdEndpointFeatureTrackerInterceptor;

impl Intercept for AccountIdEndpointFeatureTrackerInterceptor {
    fn name(&self) -> &'static str {
        "AccountIdEndpointFeatureTrackerInterceptor"
    }

    fn read_before_execution(
        &self,
        _context: &BeforeSerializationInterceptorContextRef<'_>,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        match cfg
            .load::<AccountIdEndpointMode>()
            .cloned()
            .unwrap_or_default()
        {
            AccountIdEndpointMode::Preferred => {
                cfg.interceptor_state()
                    .store_append(AwsSdkFeature::AccountIdModePreferred);
            }
            AccountIdEndpointMode::Required => {
                cfg.interceptor_state()
                    .store_append(AwsSdkFeature::AccountIdModeRequired);
            }
            AccountIdEndpointMode::Disabled => {
                cfg.interceptor_state()
                    .store_append(AwsSdkFeature::AccountIdModeDisabled);
            }
            otherwise => {
                ::tracing::warn!(
                    "Attempted to track an SDK feature for `{otherwise:?}`, which is not recognized in the current version of the SDK. \
                    Consider upgrading to the latest version to ensure that it is properly tracked."
                );
            }
        }

        Ok(())
    }
}
+4 −0
Original line number Diff line number Diff line
@@ -22,6 +22,10 @@
    unreachable_pub
)]

/// Supporting code for the account based endpoints.
#[allow(dead_code)]
pub mod account_id_endpoint;

/// Supporting code to determine auth scheme options based on the `authSchemes` endpoint list property.
#[allow(dead_code)]
pub mod endpoint_auth;
+0 −22
Original line number Diff line number Diff line
@@ -107,28 +107,6 @@ interface EndpointCustomization {
        codegenContext: ClientCodegenContext,
        params: String,
    ): Writable? = null

    /**
     * Allows tracking SDK features (generic SDK or AWS SDK) in the config bag.
     *
     * Example:
     * ```kotlin
     * override fun trackSdkFeatures(codegenContext: ClientCodegenContext, configBag: String): Writable? {
     *     return writable {
     *         rustTemplate("""
     *             $configBag.interceptor_state().store_put(#{SmithySdkFeature}::Foo);
     *         """,
     *         "SmithySdkFeature" to RuntimeType.smithyRuntime(codegenContext.runtimeConfig)
     *             .resolve("client::sdk_feature::SmithySdkFeature"),
     *         )
     *     }
     * }
     * ```
     */
    fun trackSdkFeatures(
        codegenContext: ClientCodegenContext,
        configBag: String,
    ): Writable? = null
}

/**
+0 −14
Original line number Diff line number Diff line
@@ -109,11 +109,7 @@ class EndpointParamsInterceptorGenerator(
                        #{param_setters}
                        .build()
                        .map_err(|err| #{ContextAttachedError}::new("endpoint params could not be built", err))?;

                    #{track_sdk_features:W}

                    cfg.interceptor_state().store_put(#{EndpointResolverParams}::new(params));

                    #{Ok}(())
                }
            }
@@ -127,7 +123,6 @@ class EndpointParamsInterceptorGenerator(
            "endpoint_prefix" to endpointPrefix(operationShape),
            "param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
            "jmespath_getters" to jmesPathGetters(operationShape),
            "track_sdk_features" to trackSdkFeatures(),
        )
    }

@@ -276,13 +271,4 @@ class EndpointParamsInterceptorGenerator(
                rust("cfg.interceptor_state().store_put(endpoint_prefix);")
            }
        }

    private fun trackSdkFeatures() =
        writable {
            endpointCustomizations.mapNotNull {
                it.trackSdkFeatures(codegenContext, "cfg")
            }.forEach {
                rust("#T", it)
            }
        }
}