Unverified Commit 255127a1 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Make the SDK invocation ID generator configurable (#2860)

This PR makes it possible to configure the invocation ID generator in
config.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 25abe5a8
Loading
Loading
Loading
Loading
+12 −12
Original line number Diff line number Diff line
@@ -6,13 +6,13 @@
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::Interceptor;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use fastrand::Rng;
use http::{HeaderName, HeaderValue};
use std::fmt::Debug;
use std::sync::{Arc, Mutex};

use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use fastrand::Rng;
use std::sync::Mutex;
#[cfg(feature = "test-util")]
pub use test_util::{NoInvocationIdGenerator, PredefinedInvocationIdGenerator};

@@ -27,23 +27,23 @@ pub trait InvocationIdGenerator: Debug + Send + Sync {
}

/// Dynamic dispatch implementation of [`InvocationIdGenerator`]
#[derive(Debug)]
pub struct DynInvocationIdGenerator(Box<dyn InvocationIdGenerator>);
#[derive(Clone, Debug)]
pub struct SharedInvocationIdGenerator(Arc<dyn InvocationIdGenerator>);

impl DynInvocationIdGenerator {
    /// Creates a new [`DynInvocationIdGenerator`].
impl SharedInvocationIdGenerator {
    /// Creates a new [`SharedInvocationIdGenerator`].
    pub fn new(gen: impl InvocationIdGenerator + 'static) -> Self {
        Self(Box::new(gen))
        Self(Arc::new(gen))
    }
}

impl InvocationIdGenerator for DynInvocationIdGenerator {
impl InvocationIdGenerator for SharedInvocationIdGenerator {
    fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
        self.0.generate()
    }
}

impl Storable for DynInvocationIdGenerator {
impl Storable for SharedInvocationIdGenerator {
    type Storer = StoreReplace<Self>;
}

@@ -100,7 +100,7 @@ impl Interceptor for InvocationIdInterceptor {
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let gen = cfg
            .load::<DynInvocationIdGenerator>()
            .load::<SharedInvocationIdGenerator>()
            .map(|gen| gen as &dyn InvocationIdGenerator)
            .unwrap_or(&self.default);
        if let Some(id) = gen.generate()? {
@@ -264,7 +264,7 @@ mod tests {

        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(DynInvocationIdGenerator::new(
        layer.store_put(SharedInvocationIdGenerator::new(
            PredefinedInvocationIdGenerator::new(vec![InvocationId::new(
                "the-best-invocation-id".into(),
            )]),
+71 −0
Original line number Diff line number Diff line
@@ -9,14 +9,19 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.letIf

class InvocationIdDecorator : ClientCodegenDecorator {
    override val name: String get() = "InvocationIdDecorator"
    override val order: Byte get() = 0

    override fun serviceRuntimePluginCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ServiceRuntimePluginCustomization>,
@@ -24,6 +29,14 @@ class InvocationIdDecorator : ClientCodegenDecorator {
        baseCustomizations.letIf(codegenContext.smithyRuntimeMode.generateOrchestrator) {
            it + listOf(InvocationIdRuntimePluginCustomization(codegenContext))
        }

    override fun configCustomizations(
        codegenContext: ClientCodegenContext,
        baseCustomizations: List<ConfigCustomization>,
    ): List<ConfigCustomization> =
        baseCustomizations.letIf(codegenContext.smithyRuntimeMode.generateOrchestrator) {
            it + listOf(InvocationIdConfigCustomization(codegenContext))
        }
}

private class InvocationIdRuntimePluginCustomization(
@@ -43,3 +56,61 @@ private class InvocationIdRuntimePluginCustomization(
        }
    }
}

const val GENERATOR_DOCS: String =
    "The invocation ID generator generates ID values for the `amz-sdk-invocation-id` header. " +
        "By default, this will be a random UUID. Overriding it may be useful in tests that " +
        "examine the HTTP request and need to be deterministic."

private class InvocationIdConfigCustomization(
    codegenContext: ClientCodegenContext,
) : ConfigCustomization() {
    private val awsRuntime = AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig)
    private val codegenScope = arrayOf(
        *preludeScope,
        "InvocationIdGenerator" to awsRuntime.resolve("invocation_id::InvocationIdGenerator"),
        "SharedInvocationIdGenerator" to awsRuntime.resolve("invocation_id::SharedInvocationIdGenerator"),
    )

    override fun section(section: ServiceConfig): Writable = writable {
        when (section) {
            is ServiceConfig.BuilderImpl -> {
                docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS")
                rustTemplate(
                    """
                    pub fn invocation_id_generator(mut self, gen: impl #{InvocationIdGenerator} + 'static) -> Self {
                        self.set_invocation_id_generator(#{Some}(#{SharedInvocationIdGenerator}::new(gen)));
                        self
                    }
                    """,
                    *codegenScope,
                )

                docs("Overrides the default invocation ID generator.\n\n$GENERATOR_DOCS")
                rustTemplate(
                    """
                    pub fn set_invocation_id_generator(&mut self, gen: #{Option}<#{SharedInvocationIdGenerator}>) -> &mut Self {
                        self.config.store_or_unset(gen);
                        self
                    }
                    """,
                    *codegenScope,
                )
            }

            is ServiceConfig.ConfigImpl -> {
                docs("Returns the invocation ID generator if one was given in config.\n\n$GENERATOR_DOCS")
                rustTemplate(
                    """
                    pub fn invocation_id_generator(&self) -> #{Option}<#{SharedInvocationIdGenerator}> {
                        self.config.load::<#{SharedInvocationIdGenerator}>().cloned()
                    }
                    """,
                    *codegenScope,
                )
            }

            else -> {}
        }
    }
}
+36 −34
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rustsdk.awsSdkIntegrationTest

class SdkCodegenIntegrationTest {
    companion object {
        val model = """
            namespace test

@@ -43,6 +44,7 @@ class SdkCodegenIntegrationTest {
                output: SomeOutput
            }
        """.asSmithyModel()
    }

    @Test
    fun smokeTestSdkCodegen() {
+61 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import SdkCodegenIntegrationTest
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest

class InvocationIdDecoratorTest {
    @Test
    fun customInvocationIdGenerator() {
        awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { context, rustCrate ->
            val rc = context.runtimeConfig
            val moduleName = context.moduleUseName()
            rustCrate.integrationTest("custom_invocation_id") {
                rustTemplate(
                    """
                    ##[#{tokio}::test]
                    async fn custom_invocation_id() {
                        ##[derive(::std::fmt::Debug)]
                        struct TestIdGen;
                        impl #{InvocationIdGenerator} for TestIdGen {
                            fn generate(&self) -> #{Result}<#{Option}<#{InvocationId}>, #{BoxError}> {
                                #{Ok}(#{Some}(#{InvocationId}::new("custom".into())))
                            }
                        }

                        let (conn, rx) = #{capture_request}(None);
                        let config = $moduleName::Config::builder()
                            .http_connector(conn)
                            .invocation_id_generator(TestIdGen)
                            .build();
                        assert!(config.invocation_id_generator().is_some());

                        let client = $moduleName::Client::from_conf(config);

                        let _ = client.some_operation().send().await;
                        let request = rx.expect_request();
                        assert_eq!("custom", request.headers().get("amz-sdk-invocation-id").unwrap());
                    }
                    """,
                    *preludeScope,
                    "tokio" to CargoDependency.Tokio.toType(),
                    "InvocationIdGenerator" to AwsRuntimeType.awsRuntime(rc)
                        .resolve("invocation_id::InvocationIdGenerator"),
                    "InvocationId" to AwsRuntimeType.awsRuntime(rc)
                        .resolve("invocation_id::InvocationId"),
                    "BoxError" to RuntimeType.boxError(rc),
                    "capture_request" to RuntimeType.captureRequest(rc),
                )
            }
        }
    }
}