Unverified Commit 9bfe936f authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Make the SDK ad hoc tests pass against the orchestrator (#2708)

## Motivation and Context
This PR refactors the client protocol test generator machinery to use a
client instead of calling `make_operation` directly, and then fixes the
ad hoc tests for the orchestrator.

The ad hoc tests revealed that overriding the signing region/service via
endpoint config was lost when porting SigV4 signing to the orchestrator,
so this PR updates the SigV4 `HttpRequestSigner` implementation to
restore this functionality. It is doing this in the signer directly
rather than via an interceptor since it should only run this logic when
SigV4 is the selected auth scheme.

Other notable changes:
- Adds `--no-fail-fast` arg to `cargoTest` targets so that all Rust
tests run in CI rather than stopping on the first failure
- Changes `EndpointResolver::resolve_and_apply_endpoint` to just
`resolve_endpoint` so that the orchestrator can place the endpoint
config into the request state, which is required for the signer to make
use of it
- Adds a `set_region` method to SDK service configs
- Deletes the API Gateway model and integration test from the SDK smoke
test since it is covered by the ad hoc tests
- Adds a comment explaining where the API Gateway model comes from in
the ad hoc tests
- Adds a `smithy.runtime.mode` Gradle property to `aws:sdk` and
`aws:sdk-adhoc-test` to trivially switch between middleware and
orchestrator when testing/generating locally

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 41774b84
Loading
Loading
Loading
Loading
+197 −18
Original line number Original line Diff line number Diff line
@@ -11,12 +11,18 @@ pub mod sigv4 {
        SignableRequest, SignatureLocation, SigningParams, SigningSettings,
        SignableRequest, SignatureLocation, SigningParams, SigningSettings,
        UriPathNormalizationMode,
        UriPathNormalizationMode,
    };
    };
    use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
    use aws_smithy_runtime_api::client::auth::{
        AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner,
    };
    use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
    use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
    use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
    use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
    use aws_smithy_runtime_api::config_bag::ConfigBag;
    use aws_smithy_runtime_api::config_bag::ConfigBag;
    use aws_types::region::SigningRegion;
    use aws_smithy_types::Document;
    use aws_types::region::{Region, SigningRegion};
    use aws_types::SigningService;
    use aws_types::SigningService;
    use std::borrow::Cow;
    use std::error::Error as StdError;
    use std::fmt;
    use std::time::{Duration, SystemTime};
    use std::time::{Duration, SystemTime};


    const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \
    const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \
@@ -25,6 +31,53 @@ pub mod sigv4 {
    /// Auth scheme ID for SigV4.
    /// Auth scheme ID for SigV4.
    pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4");
    pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4");


    struct EndpointAuthSchemeConfig {
        signing_region_override: Option<SigningRegion>,
        signing_service_override: Option<SigningService>,
    }

    #[derive(Debug)]
    enum SigV4SigningError {
        MissingOperationSigningConfig,
        MissingSigningRegion,
        MissingSigningService,
        WrongIdentityType(Identity),
        BadTypeInEndpointAuthSchemeConfig(&'static str),
    }

    impl fmt::Display for SigV4SigningError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            use SigV4SigningError::*;
            let mut w = |s| f.write_str(s);
            match self {
                MissingOperationSigningConfig => w("missing operation signing config for SigV4"),
                MissingSigningRegion => w("missing signing region for SigV4 signing"),
                MissingSigningService => w("missing signing service for SigV4 signing"),
                WrongIdentityType(identity) => {
                    write!(f, "wrong identity type for SigV4: {identity:?}")
                }
                BadTypeInEndpointAuthSchemeConfig(field_name) => {
                    write!(
                        f,
                        "unexpected type for `{field_name}` in endpoint auth scheme config",
                    )
                }
            }
        }
    }

    impl StdError for SigV4SigningError {
        fn source(&self) -> Option<&(dyn StdError + 'static)> {
            match self {
                Self::MissingOperationSigningConfig => None,
                Self::MissingSigningRegion => None,
                Self::MissingSigningService => None,
                Self::WrongIdentityType(_) => None,
                Self::BadTypeInEndpointAuthSchemeConfig(_) => None,
            }
        }
    }

    /// SigV4 auth scheme.
    /// SigV4 auth scheme.
    #[derive(Debug, Default)]
    #[derive(Debug, Default)]
    pub struct SigV4HttpAuthScheme {
    pub struct SigV4HttpAuthScheme {
@@ -111,9 +164,9 @@ pub mod sigv4 {
    #[derive(Clone, Debug, PartialEq, Eq)]
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub struct SigV4OperationSigningConfig {
    pub struct SigV4OperationSigningConfig {
        /// AWS Region to sign for.
        /// AWS Region to sign for.
        pub region: SigningRegion,
        pub region: Option<SigningRegion>,
        /// AWS Service to sign for.
        /// AWS Service to sign for.
        pub service: SigningService,
        pub service: Option<SigningService>,
        /// Signing options.
        /// Signing options.
        pub signing_options: SigningOptions,
        pub signing_options: SigningOptions,
    }
    }
@@ -165,7 +218,7 @@ pub mod sigv4 {
            credentials: &'a Credentials,
            credentials: &'a Credentials,
            operation_config: &'a SigV4OperationSigningConfig,
            operation_config: &'a SigV4OperationSigningConfig,
            request_timestamp: SystemTime,
            request_timestamp: SystemTime,
        ) -> SigningParams<'a> {
        ) -> Result<SigningParams<'a>, SigV4SigningError> {
            if let Some(expires_in) = settings.expires_in {
            if let Some(expires_in) = settings.expires_in {
                if let Some(creds_expires_time) = credentials.expiry() {
                if let Some(creds_expires_time) = credentials.expiry() {
                    let presigned_expires_time = request_timestamp + expires_in;
                    let presigned_expires_time = request_timestamp + expires_in;
@@ -178,12 +231,75 @@ pub mod sigv4 {
            let mut builder = SigningParams::builder()
            let mut builder = SigningParams::builder()
                .access_key(credentials.access_key_id())
                .access_key(credentials.access_key_id())
                .secret_key(credentials.secret_access_key())
                .secret_key(credentials.secret_access_key())
                .region(operation_config.region.as_ref())
                .region(
                .service_name(operation_config.service.as_ref())
                    operation_config
                        .region
                        .as_ref()
                        .ok_or(SigV4SigningError::MissingSigningRegion)?
                        .as_ref(),
                )
                .service_name(
                    operation_config
                        .service
                        .as_ref()
                        .ok_or(SigV4SigningError::MissingSigningService)?
                        .as_ref(),
                )
                .time(request_timestamp)
                .time(request_timestamp)
                .settings(settings);
                .settings(settings);
            builder.set_security_token(credentials.session_token());
            builder.set_security_token(credentials.session_token());
            builder.build().expect("all required fields set")
            Ok(builder.build().expect("all required fields set"))
        }

        fn extract_operation_config<'a>(
            auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'a>,
            config_bag: &'a ConfigBag,
        ) -> Result<Cow<'a, SigV4OperationSigningConfig>, SigV4SigningError> {
            let operation_config = config_bag
                .get::<SigV4OperationSigningConfig>()
                .ok_or(SigV4SigningError::MissingOperationSigningConfig)?;

            let EndpointAuthSchemeConfig {
                signing_region_override,
                signing_service_override,
            } = Self::extract_endpoint_auth_scheme_config(auth_scheme_endpoint_config)?;

            match (signing_region_override, signing_service_override) {
                (None, None) => Ok(Cow::Borrowed(operation_config)),
                (region, service) => {
                    let mut operation_config = operation_config.clone();
                    if region.is_some() {
                        operation_config.region = region;
                    }
                    if service.is_some() {
                        operation_config.service = service;
                    }
                    Ok(Cow::Owned(operation_config))
                }
            }
        }

        fn extract_endpoint_auth_scheme_config(
            endpoint_config: AuthSchemeEndpointConfig<'_>,
        ) -> Result<EndpointAuthSchemeConfig, SigV4SigningError> {
            let (mut signing_region_override, mut signing_service_override) = (None, None);
            if let Some(config) = endpoint_config.config().and_then(Document::as_object) {
                use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType;
                signing_region_override = match config.get("signingRegion") {
                    Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
                    None => None,
                    _ => return Err(UnexpectedType("signingRegion")),
                };
                signing_service_override = match config.get("signingName") {
                    Some(Document::String(s)) => Some(SigningService::from(s.to_string())),
                    None => None,
                    _ => return Err(UnexpectedType("signingName")),
                };
            }
            Ok(EndpointAuthSchemeConfig {
                signing_region_override,
                signing_service_override,
            })
        }
        }
    }
    }


@@ -192,11 +308,11 @@ pub mod sigv4 {
            &self,
            &self,
            request: &mut HttpRequest,
            request: &mut HttpRequest,
            identity: &Identity,
            identity: &Identity,
            auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
            config_bag: &ConfigBag,
            config_bag: &ConfigBag,
        ) -> Result<(), BoxError> {
        ) -> Result<(), BoxError> {
            let operation_config = config_bag
            let operation_config =
                .get::<SigV4OperationSigningConfig>()
                Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?;
                .ok_or("missing operation signing config for SigV4")?;
            let request_time = config_bag.request_time().unwrap_or_default().system_time();
            let request_time = config_bag.request_time().unwrap_or_default().system_time();


            let credentials = if let Some(creds) = identity.data::<Credentials>() {
            let credentials = if let Some(creds) = identity.data::<Credentials>() {
@@ -205,12 +321,12 @@ pub mod sigv4 {
                tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials");
                tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials");
                return Ok(());
                return Ok(());
            } else {
            } else {
                return Err(format!("wrong identity type for SigV4: {identity:?}").into());
                return Err(SigV4SigningError::WrongIdentityType(identity.clone()).into());
            };
            };


            let settings = Self::settings(operation_config);
            let settings = Self::settings(&operation_config);
            let signing_params =
            let signing_params =
                Self::signing_params(settings, credentials, operation_config, request_time);
                Self::signing_params(settings, credentials, &operation_config, request_time)?;


            let (signing_instructions, _signature) = {
            let (signing_instructions, _signature) = {
                // A body that is already in memory can be signed directly. A body that is not in memory
                // A body that is already in memory can be signed directly. A body that is not in memory
@@ -250,6 +366,9 @@ pub mod sigv4 {
        use super::*;
        use super::*;
        use aws_credential_types::Credentials;
        use aws_credential_types::Credentials;
        use aws_sigv4::http_request::SigningSettings;
        use aws_sigv4::http_request::SigningSettings;
        use aws_types::region::SigningRegion;
        use aws_types::SigningService;
        use std::collections::HashMap;
        use std::time::{Duration, SystemTime};
        use std::time::{Duration, SystemTime};
        use tracing_test::traced_test;
        use tracing_test::traced_test;


@@ -270,8 +389,8 @@ pub mod sigv4 {
                "test",
                "test",
            );
            );
            let operation_config = SigV4OperationSigningConfig {
            let operation_config = SigV4OperationSigningConfig {
                region: SigningRegion::from_static("test"),
                region: Some(SigningRegion::from_static("test")),
                service: SigningService::from_static("test"),
                service: Some(SigningService::from_static("test")),
                signing_options: SigningOptions {
                signing_options: SigningOptions {
                    double_uri_encode: true,
                    double_uri_encode: true,
                    content_sha256_header: true,
                    content_sha256_header: true,
@@ -283,14 +402,74 @@ pub mod sigv4 {
                    payload_override: None,
                    payload_override: None,
                },
                },
            };
            };
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
                .unwrap();
            assert!(!logs_contain(EXPIRATION_WARNING));
            assert!(!logs_contain(EXPIRATION_WARNING));


            let mut settings = SigningSettings::default();
            let mut settings = SigningSettings::default();
            settings.expires_in = Some(creds_expire_in + Duration::from_secs(10));
            settings.expires_in = Some(creds_expire_in + Duration::from_secs(10));


            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now)
                .unwrap();
            assert!(logs_contain(EXPIRATION_WARNING));
            assert!(logs_contain(EXPIRATION_WARNING));
        }
        }

        #[test]
        fn endpoint_config_overrides_region_and_service() {
            let mut cfg = ConfigBag::base();
            cfg.put(SigV4OperationSigningConfig {
                region: Some(SigningRegion::from(Region::new("override-this-region"))),
                service: Some(SigningService::from_static("override-this-service")),
                signing_options: Default::default(),
            });
            let config = Document::Object({
                let mut out = HashMap::new();
                out.insert("name".to_string(), "sigv4".to_string().into());
                out.insert(
                    "signingName".to_string(),
                    "qldb-override".to_string().into(),
                );
                out.insert(
                    "signingRegion".to_string(),
                    "us-east-override".to_string().into(),
                );
                out
            });
            let config = AuthSchemeEndpointConfig::new(Some(&config));

            let result =
                SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");

            assert_eq!(
                result.region,
                Some(SigningRegion::from(Region::new("us-east-override")))
            );
            assert_eq!(
                result.service,
                Some(SigningService::from_static("qldb-override"))
            );
            assert!(matches!(result, Cow::Owned(_)));
        }

        #[test]
        fn endpoint_config_supports_fallback_when_region_or_service_are_unset() {
            let mut cfg = ConfigBag::base();
            cfg.put(SigV4OperationSigningConfig {
                region: Some(SigningRegion::from(Region::new("us-east-1"))),
                service: Some(SigningService::from_static("qldb")),
                signing_options: Default::default(),
            });
            let config = AuthSchemeEndpointConfig::empty();

            let result =
                SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success");

            assert_eq!(
                result.region,
                Some(SigningRegion::from(Region::new("us-east-1")))
            );
            assert_eq!(result.service, Some(SigningService::from_static("qldb")));
            assert!(matches!(result, Cow::Borrowed(_)));
        }
    }
    }
}
}
+4 −0
Original line number Original line Diff line number Diff line
@@ -37,6 +37,8 @@ dependencies {
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
}
}


fun getSmithyRuntimeMode(): String = properties.get("smithy.runtime.mode") ?: "middleware"

val allCodegenTests = listOf(
val allCodegenTests = listOf(
    CodegenTest(
    CodegenTest(
        "com.amazonaws.apigateway#BackplaneControlService",
        "com.amazonaws.apigateway#BackplaneControlService",
@@ -46,6 +48,7 @@ val allCodegenTests = listOf(
            ,
            ,
            "codegen": {
            "codegen": {
                "includeFluentClient": false,
                "includeFluentClient": false,
                "enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
            },
            },
            "customizationConfig": {
            "customizationConfig": {
                "awsSdk": {
                "awsSdk": {
@@ -62,6 +65,7 @@ val allCodegenTests = listOf(
            ,
            ,
            "codegen": {
            "codegen": {
                "includeFluentClient": false,
                "includeFluentClient": false,
                "enableNewSmithyRuntime": "${getSmithyRuntimeMode()}"
            },
            },
            "customizationConfig": {
            "customizationConfig": {
                "awsSdk": {
                "awsSdk": {
+8 −0
Original line number Original line Diff line number Diff line
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// The API Gateway model is coming from Smithy's protocol tests, and includes an `Accept` header test:
// https://github.com/awslabs/smithy/blob/2f6553ff39e6bba9edc644ef5832661821785319/smithy-aws-protocol-tests/model/restJson1/services/apigateway.smithy#L30-L43

$version: "1.0"
$version: "1.0"


namespace com.amazonaws.apigateway
namespace com.amazonaws.apigateway


use smithy.rules#endpointRuleSet
use smithy.rules#endpointRuleSet

// Add an endpoint ruleset to the Smithy protocol test API Gateway model so that the code generator doesn't fail
apply BackplaneControlService @endpointRuleSet({
apply BackplaneControlService @endpointRuleSet({
    "version": "1.0",
    "version": "1.0",
    "rules": [{
    "rules": [{
+25 −0
Original line number Original line Diff line number Diff line
@@ -14,6 +14,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.NoClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.NoClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
@@ -96,6 +98,29 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
            }
            }
        }
        }
    }
    }

    override fun protocolTestGenerator(
        codegenContext: ClientCodegenContext,
        baseGenerator: ProtocolTestGenerator,
    ): ProtocolTestGenerator = DefaultProtocolTestGenerator(
        codegenContext,
        baseGenerator.protocolSupport,
        baseGenerator.operationShape,
        renderClientCreation = { params ->
            rustTemplate(
                """
                // If the test case was missing endpoint parameters, default a region so it doesn't fail
                let mut ${params.configBuilderName} = ${params.configBuilderName};
                if ${params.configBuilderName}.region.is_none() {
                    ${params.configBuilderName}.set_region(Some(crate::config::Region::new("us-east-1")));
                }
                let config = ${params.configBuilderName}.http_connector(${params.connectorName}).build();
                let ${params.clientName} = #{Client}::from_conf(config);
                """,
                "Client" to ClientRustModule.root.toType().resolve("Client"),
            )
        },
    )
}
}


private class AwsFluentClientExtensions(types: Types) {
private class AwsFluentClientExtensions(types: Types) {
+7 −1
Original line number Original line Diff line number Diff line
@@ -171,7 +171,7 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
            )
            )


            ServiceConfig.BuilderStruct ->
            ServiceConfig.BuilderStruct ->
                rustTemplate("region: Option<#{Region}>,", *codegenScope)
                rustTemplate("pub(crate) region: Option<#{Region}>,", *codegenScope)


            ServiceConfig.BuilderImpl ->
            ServiceConfig.BuilderImpl ->
                rustTemplate(
                rustTemplate(
@@ -191,6 +191,12 @@ class RegionProviderConfig(codegenContext: CodegenContext) : ConfigCustomization
                        self.region = region.into();
                        self.region = region.into();
                        self
                        self
                    }
                    }

                    /// Sets the AWS region to use when making requests.
                    pub fn set_region(&mut self, region: Option<#{Region}>) -> &mut Self {
                        self.region = region;
                        self
                    }
                    """,
                    """,
                    *codegenScope,
                    *codegenScope,
                )
                )
Loading