Unverified Commit e8fcd45f authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Extract SdkConfig parts into decorators (#2140)

* Extract SdkConfig parts into decorators

* CR feedback

* Fix HTTP connector
parent c49d5961
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
    HttpConnectorDecorator(),
    AwsEndpointsStdLib(),
    AddFIPSDualStackDecorator(),
    GenericSmithySdkConfigSettings(),

    // Service specific decorators
    ApiGatewayDecorator(),
+23 −2
Original line number Diff line number Diff line
@@ -31,10 +31,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.util.extendIf
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class AwsEndpointDecorator : ClientCodegenDecorator {
    override val name: String = "AwsEndpoint"
@@ -78,7 +81,7 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
    ): List<ConfigCustomization> {
        return baseCustomizations.extendIf(codegenContext.getBuiltIn(Builtins.REGION) != null) {
        return baseCustomizations.extendIf(codegenContext.isRegionalized()) {
            AwsEndpointShimCustomization(codegenContext)
        } + SdkEndpointCustomization(
            codegenContext,
@@ -101,7 +104,7 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
            )
        }
        // generate a region converter if params has a region
        if (!epTypes.params.toList().any { it.builtIn == Builtins.REGION.builtIn }) {
        if (!codegenContext.isRegionalized()) {
            println("not generating a resolver for ${codegenContext.serviceShape}")
            return
        }
@@ -127,6 +130,21 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
        }
    }

    override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> {
        return codegenContext.isRegionalized().thenSingletonListOf {
            SdkConfigSection.create { section ->
                {
                    rust(
                        """
                        ${section.serviceConfigBuilder}.set_aws_endpoint_resolver(${section.sdkConfig}.endpoint_resolver().clone());
                        ${section.serviceConfigBuilder}.set_endpoint_url(${section.sdkConfig}.endpoint_url().map(|url|url.to_string()));
                        """,
                    )
                }
            }
        }
    }

    override fun endpointCustomizations(codegenContext: ClientCodegenContext): List<EndpointCustomization> {
        return listOf(
            object : EndpointCustomization {
@@ -150,6 +168,7 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
            "EndpointShim" to endpointShim,
            "aws_types" to AwsRuntimeType.awsTypes(runtimeConfig),
        )

        override fun section(section: ServiceConfig) = writable {
            when (section) {
                ServiceConfig.BuilderImpl -> rustTemplate(
@@ -267,3 +286,5 @@ class AwsEndpointDecorator : ClientCodegenDecorator {
        }
    }
}

fun ClientCodegenContext.isRegionalized() = getBuiltIn(Builtins.REGION) != null
+14 −0
Original line number Diff line number Diff line
@@ -16,8 +16,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection

@@ -46,6 +48,15 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
    ): List<LibRsCustomization> {
        return baseCustomizations + PubUseCredentials(codegenContext.runtimeConfig)
    }

    override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> =
        listOf(
            SdkConfigSection.create { section ->
                writable {
                    rust("${section.serviceConfigBuilder}.set_credentials_provider(${section.sdkConfig}.credentials_provider().cloned());")
                }
            },
        )
}

/**
@@ -64,6 +75,7 @@ class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomizati
                """pub(crate) credentials_provider: #{provider}::SharedCredentialsProvider,""",
                *codegenScope,
            )

            ServiceConfig.ConfigImpl -> rustTemplate(
                """
                /// Returns the credentials provider.
@@ -73,8 +85,10 @@ class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomizati
                """,
                *codegenScope,
            )

            ServiceConfig.BuilderStruct ->
                rustTemplate("credentials_provider: Option<#{provider}::SharedCredentialsProvider>,", *codegenScope)

            ServiceConfig.BuilderImpl -> {
                rustTemplate(
                    """
+18 −0
Original line number Diff line number Diff line
@@ -21,12 +21,15 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.extendIf
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

/* Example Generated Code */
/*
@@ -106,6 +109,21 @@ class RegionDecorator : ClientCodegenDecorator {
        return baseCustomizations.extendIf(usesRegion(codegenContext)) { PubUseRegion(codegenContext.runtimeConfig) }
    }

    override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> {
        return usesRegion(codegenContext).thenSingletonListOf {
            SdkConfigSection.create { section ->
                {
                    rust(
                        """
                        ${section.serviceConfigBuilder} =
                             ${section.serviceConfigBuilder}.region(${section.sdkConfig}.region().cloned());
                        """,
                    )
                }
            }
        }
    }

    override fun endpointCustomizations(codegenContext: ClientCodegenContext): List<EndpointCustomization> {
        if (!usesRegion(codegenContext)) {
            return listOf()
+59 −17
Original line number Diff line number Diff line
@@ -5,18 +5,64 @@

package software.amazon.smithy.rustsdk

import software.amazon.smithy.rulesengine.language.syntax.parameters.Builtins
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocSection
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section

/**
 * Section enabling linkage between `SdkConfig` and <service>::Config
 */
object SdkConfigSection : AdHocSection<SdkConfigSection.CopySdkConfigToClientConfig>("SdkConfig") {
    /**
     * [sdkConfig]: A reference to the SDK config struct
     * [serviceConfigBuilder]: A reference (owned) to the `<service>::config::Builder` struct.
     *
     * Each invocation of this section MUST be a complete statement (ending with a semicolon), e.g:
     * ```
     * rust("${section.serviceConfigBuilder}.set_foo(${section.sdkConfig}.foo());")
     * ```
     */
    data class CopySdkConfigToClientConfig(val sdkConfig: String, val serviceConfigBuilder: String) :
        Section("CopyConfig")
}

/**
 * SdkConfig -> <service>::Config for settings that come from generic smithy
 */
class GenericSmithySdkConfigSettings : ClientCodegenDecorator {
    override val name: String = "GenericSmithySdkConfigSettings"
    override val order: Byte = 0

    override fun extraSections(codegenContext: ClientCodegenContext): List<Pair<AdHocSection<*>, (Section) -> Writable>> =
        listOf(
            SdkConfigSection.create { section ->
                writable {
                    rust(
                        """
                        // resiliency
                        ${section.serviceConfigBuilder}.set_retry_config(${section.sdkConfig}.retry_config().cloned());
                        ${section.serviceConfigBuilder}.set_timeout_config(${section.sdkConfig}.timeout_config().cloned());
                        ${section.serviceConfigBuilder}.set_sleep_impl(${section.sdkConfig}.sleep_impl());

                        ${section.serviceConfigBuilder}.set_http_connector(${section.sdkConfig}.http_connector().cloned());

                        """,
                    )
                }
            },
        )
}

/**
 * Adds functionality for constructing `<service>::Config` objects from `aws_types::SdkConfig`s
@@ -39,27 +85,15 @@ class SdkConfigDecorator : ClientCodegenDecorator {
        val codegenScope = arrayOf(
            "SdkConfig" to AwsRuntimeType.awsTypes(codegenContext.runtimeConfig).resolve("sdk_config::SdkConfig"),
        )
        val regionalizedBits = writable {
            if (codegenContext.getBuiltIn(Builtins.REGION) != null) {
                rust("builder = builder.region(input.region().cloned());")
                rust("builder.set_aws_endpoint_resolver(input.endpoint_resolver().clone());")
            }
        }
        rustCrate.withModule(RustModule.Config) {
            // !!NOTE!! As more items are added to aws_types::SdkConfig, use them here to configure the config builder
            rustTemplate(
                """
                impl From<&#{SdkConfig}> for Builder {
                    fn from(input: &#{SdkConfig}) -> Self {
                        let mut builder = Builder::default();
                        #{regionalized}
                        builder.set_endpoint_url(input.endpoint_url().map(|url|url.to_string()));
                        builder.set_retry_config(input.retry_config().cloned());
                        builder.set_timeout_config(input.timeout_config().cloned());
                        builder.set_sleep_impl(input.sleep_impl());
                        builder.set_credentials_provider(input.credentials_provider().cloned());
                        builder.set_app_name(input.app_name().cloned());
                        builder.set_http_connector(input.http_connector().cloned());
                        #{augmentBuilder}


                        builder
                    }
                }
@@ -70,8 +104,16 @@ class SdkConfigDecorator : ClientCodegenDecorator {
                    }
                }
                """,
                "augmentBuilder" to codegenContext.rootDecorator.extraSections(codegenContext)
                    .filter { (t, _) -> t is SdkConfigSection }.map { (_, sectionWriter) ->
                        sectionWriter(
                            SdkConfigSection.CopySdkConfigToClientConfig(
                                sdkConfig = "input",
                                serviceConfigBuilder = "builder",
                            ),
                        )
                    }.join("\n"),
                *codegenScope,
                "regionalized" to regionalizedBits,
            )
        }
    }
Loading