Unverified Commit b1ce5e3c authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Recalculate signing time for every retry attempt (#2643)

## Motivation and Context
When SigV4 signing was ported over to the orchestrator, the request time
got calculated once before the retry loop, which was incorrect. This PR
moves that request time calculation into the request signer so that it
happens on every attempt (unless there is a configured request time
override).

This PR also refactors auth to use the `ConfigBag` instead of a separate
signing properties `PropertyBag`.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 5a8fff7a
Loading
Loading
Loading
Loading
+15 −19
Original line number Diff line number Diff line
@@ -11,11 +11,10 @@ pub mod sigv4 {
        SignableRequest, SignatureLocation, SigningParams, SigningSettings,
        UriPathNormalizationMode,
    };
    use aws_smithy_http::property_bag::PropertyBag;
    use aws_smithy_runtime_api::client::auth::{AuthSchemeId, HttpAuthScheme, HttpRequestSigner};
    use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers};
    use aws_smithy_runtime_api::client::orchestrator::{
        BoxError, HttpAuthScheme, HttpRequest, HttpRequestSigner,
    };
    use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest};
    use aws_smithy_runtime_api::config_bag::ConfigBag;
    use aws_types::region::SigningRegion;
    use aws_types::SigningService;
    use std::time::{Duration, SystemTime};
@@ -24,7 +23,7 @@ pub mod sigv4 {
        `expires_in` duration because the credentials used to sign it will expire first.";

    /// Auth scheme ID for SigV4.
    pub const SCHEME_ID: &str = "sigv4";
    pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4");

    /// SigV4 auth scheme.
    #[derive(Debug, Default)]
@@ -40,7 +39,7 @@ pub mod sigv4 {
    }

    impl HttpAuthScheme for SigV4HttpAuthScheme {
        fn scheme_id(&self) -> &'static str {
        fn scheme_id(&self) -> AuthSchemeId {
            SCHEME_ID
        }

@@ -88,8 +87,6 @@ pub mod sigv4 {
        pub signing_optional: bool,
        /// Optional expiration (for presigning)
        pub expires_in: Option<Duration>,
        /// Timestamp to sign with.
        pub request_timestamp: SystemTime,
    }

    impl Default for SigningOptions {
@@ -103,7 +100,6 @@ pub mod sigv4 {
                signature_type: HttpSignatureType::HttpRequestHeaders,
                signing_optional: false,
                expires_in: None,
                request_timestamp: SystemTime::now(),
            }
        }
    }
@@ -168,11 +164,11 @@ pub mod sigv4 {
            settings: SigningSettings,
            credentials: &'a Credentials,
            operation_config: &'a SigV4OperationSigningConfig,
            request_timestamp: SystemTime,
        ) -> SigningParams<'a> {
            if let Some(expires_in) = settings.expires_in {
                if let Some(creds_expires_time) = credentials.expiry() {
                    let presigned_expires_time =
                        operation_config.signing_options.request_timestamp + expires_in;
                    let presigned_expires_time = request_timestamp + expires_in;
                    if presigned_expires_time > creds_expires_time {
                        tracing::warn!(EXPIRATION_WARNING);
                    }
@@ -184,7 +180,7 @@ pub mod sigv4 {
                .secret_key(credentials.secret_access_key())
                .region(operation_config.region.as_ref())
                .service_name(operation_config.service.as_ref())
                .time(operation_config.signing_options.request_timestamp)
                .time(request_timestamp)
                .settings(settings);
            builder.set_security_token(credentials.session_token());
            builder.build().expect("all required fields set")
@@ -196,12 +192,12 @@ pub mod sigv4 {
            &self,
            request: &mut HttpRequest,
            identity: &Identity,
            // TODO(enableNewSmithyRuntime): should this be the config bag?
            signing_properties: &PropertyBag,
            config_bag: &ConfigBag,
        ) -> Result<(), BoxError> {
            let operation_config = signing_properties
            let operation_config = config_bag
                .get::<SigV4OperationSigningConfig>()
                .ok_or("missing operation signing config for SigV4")?;
            let request_time = config_bag.request_time().unwrap_or_default().system_time();

            let credentials = if let Some(creds) = identity.data::<Credentials>() {
                creds
@@ -213,7 +209,8 @@ pub mod sigv4 {
            };

            let settings = Self::settings(operation_config);
            let signing_params = Self::signing_params(settings, credentials, operation_config);
            let signing_params =
                Self::signing_params(settings, credentials, operation_config, request_time);

            let (signing_instructions, _signature) = {
                // A body that is already in memory can be signed directly. A body that is not in memory
@@ -283,17 +280,16 @@ pub mod sigv4 {
                    signature_type: HttpSignatureType::HttpRequestHeaders,
                    signing_optional: false,
                    expires_in: None,
                    request_timestamp: now,
                    payload_override: None,
                },
            };
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config);
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
            assert!(!logs_contain(EXPIRATION_WARNING));

            let mut settings = SigningSettings::default();
            settings.expires_in = Some(creds_expire_in + Duration::from_secs(10));

            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config);
            SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now);
            assert!(logs_contain(EXPIRATION_WARNING));
        }
    }
+2 −2
Original line number Diff line number Diff line
@@ -6,9 +6,9 @@
/// Credentials-based identity support.
pub mod credentials {
    use aws_credential_types::cache::SharedCredentialsCache;
    use aws_smithy_http::property_bag::PropertyBag;
    use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver};
    use aws_smithy_runtime_api::client::orchestrator::{BoxError, Future};
    use aws_smithy_runtime_api::config_bag::ConfigBag;

    /// Smithy identity resolver for AWS credentials.
    #[derive(Debug)]
@@ -24,7 +24,7 @@ pub mod credentials {
    }

    impl IdentityResolver for CredentialsIdentityResolver {
        fn resolve_identity(&self, _identity_properties: &PropertyBag) -> Future<Identity> {
        fn resolve_identity(&self, _config_bag: &ConfigBag) -> Future<Identity> {
            let cache = self.credentials_cache.clone();
            Future::new(Box::pin(async move {
                let credentials = cache.as_ref().provide_cached_credentials().await?;
+5 −8
Original line number Diff line number Diff line
@@ -102,10 +102,8 @@ private class AuthOperationRuntimePluginCustomization(private val codegenContext
        val runtimeApi = RuntimeType.smithyRuntimeApi(runtimeConfig)
        val awsRuntime = AwsRuntimeType.awsRuntime(runtimeConfig)
        arrayOf(
            "AuthOptionListResolver" to runtimeApi.resolve("client::auth::option_resolver::AuthOptionListResolver"),
            "HttpAuthOption" to runtimeApi.resolve("client::orchestrator::HttpAuthOption"),
            "StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
            "HttpSignatureType" to awsRuntime.resolve("auth::sigv4::HttpSignatureType"),
            "PropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::PropertyBag"),
            "SIGV4_SCHEME_ID" to awsRuntime.resolve("auth::sigv4::SCHEME_ID"),
            "SigV4OperationSigningConfig" to awsRuntime.resolve("auth::sigv4::SigV4OperationSigningConfig"),
            "SigningOptions" to awsRuntime.resolve("auth::sigv4::SigningOptions"),
@@ -136,16 +134,15 @@ private class AuthOperationRuntimePluginCustomization(private val codegenContext
                        signing_options.normalize_uri_path = $normalizeUrlPath;
                        signing_options.signing_optional = $signingOptional;
                        signing_options.payload_override = #{payload_override};
                        signing_options.request_timestamp = cfg.request_time().unwrap_or_default().system_time();

                        let mut sigv4_properties = #{PropertyBag}::new();
                        sigv4_properties.insert(#{SigV4OperationSigningConfig} {
                        ${section.configBagName}.put(#{SigV4OperationSigningConfig} {
                            region: signing_region,
                            service: signing_service,
                            signing_options,
                        });
                        let auth_option_resolver = #{AuthOptionListResolver}::new(
                            vec![#{HttpAuthOption}::new(#{SIGV4_SCHEME_ID}, std::sync::Arc::new(sigv4_properties))]
                        // TODO(enableNewSmithyRuntime): Make auth options additive in the config bag so that multiple codegen decorators can register them
                        let auth_option_resolver = #{StaticAuthOptionResolver}::new(
                            vec![#{SIGV4_SCHEME_ID}]
                        );
                        ${section.configBagName}.set_auth_option_resolver(auth_option_resolver);
                        """,
+37 −219
Original line number Diff line number Diff line
@@ -5,16 +5,13 @@

#[macro_use]
extern crate criterion;
use aws_credential_types::cache::{CredentialsCache, SharedCredentialsCache};
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_sdk_s3 as s3;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::test_connection::infallible_connection_fn;
use aws_smithy_http::endpoint::SharedEndpointResolver;
use aws_smithy_runtime_api::type_erasure::TypedBox;
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_runtime_api::config_bag::ConfigBag;
use criterion::Criterion;
use s3::operation::list_objects_v2::{ListObjectsV2Error, ListObjectsV2Input, ListObjectsV2Output};
use s3::endpoint::Params;

async fn middleware(client: &s3::Client) {
    client
@@ -26,41 +23,36 @@ async fn middleware(client: &s3::Client) {
        .expect("successful execution");
}

async fn orchestrator(
    connector: &DynConnector,
    endpoint_resolver: SharedEndpointResolver<s3::endpoint::Params>,
    credentials_cache: SharedCredentialsCache,
) {
    let service_runtime_plugin = orchestrator::ManualServiceRuntimePlugin {
        connector: connector.clone(),
        endpoint_resolver: endpoint_resolver.clone(),
        credentials_cache: credentials_cache.clone(),
    };
async fn orchestrator(client: &s3::Client) {
    struct FixupPlugin {
        region: String,
    }
    impl RuntimePlugin for FixupPlugin {
        fn configure(
            &self,
            cfg: &mut ConfigBag,
        ) -> Result<(), aws_smithy_runtime_api::client::runtime_plugin::BoxError> {
            let params_builder = Params::builder()
                .set_region(Some(self.region.clone()))
                .bucket("test-bucket");

    // TODO(enableNewSmithyRuntime): benchmark with `send_v2` directly once it works
    let runtime_plugins = aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugins::new()
        .with_client_plugin(service_runtime_plugin)
        .with_operation_plugin(aws_sdk_s3::operation::list_objects_v2::ListObjectsV2::new())
        .with_operation_plugin(orchestrator::ManualOperationRuntimePlugin);
    let input = ListObjectsV2Input::builder()
            cfg.put(params_builder);
            Ok(())
        }
    }
    let _output = client
        .list_objects_v2()
        .bucket("test-bucket")
        .prefix("prefix~")
        .build()
        .unwrap();
    let input = TypedBox::new(input).erase();
    let output = aws_smithy_runtime::client::orchestrator::invoke(input, &runtime_plugins)
        .send_v2_with_plugin(Some(FixupPlugin {
            region: client
                .conf()
                .region()
                .map(|c| c.as_ref().to_string())
                .unwrap(),
        }))
        .await
        .map_err(|err| {
            err.map_service_error(|err| {
                TypedBox::<ListObjectsV2Error>::assume_from(err)
                    .expect("correct error type")
                    .unwrap()
            })
        })
        .unwrap();
    TypedBox::<ListObjectsV2Output>::assume_from(output)
        .expect("correct output type")
        .unwrap();
        .expect("successful execution");
}

fn test_connection() -> DynConnector {
@@ -93,14 +85,18 @@ fn test_connection() -> DynConnector {
    })
}

fn middleware_bench(c: &mut Criterion) {
fn client() -> s3::Client {
    let conn = test_connection();
    let config = s3::Config::builder()
        .credentials_provider(s3::config::Credentials::for_tests())
        .region(s3::config::Region::new("us-east-1"))
        .http_connector(conn.clone())
        .build();
    let client = s3::Client::from_conf(config);
    s3::Client::from_conf(config)
}

fn middleware_bench(c: &mut Criterion) {
    let client = client();
    c.bench_function("middleware", move |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter(|| async { middleware(&client).await })
@@ -108,190 +104,12 @@ fn middleware_bench(c: &mut Criterion) {
}

fn orchestrator_bench(c: &mut Criterion) {
    let conn = test_connection();
    let endpoint_resolver = SharedEndpointResolver::new(s3::endpoint::DefaultResolver::new());
    let credentials_cache = SharedCredentialsCache::new(
        CredentialsCache::lazy()
            .create_cache(SharedCredentialsProvider::new(Credentials::for_tests())),
    );

    let client = client();
    c.bench_function("orchestrator", move |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter(|| async {
                orchestrator(&conn, endpoint_resolver.clone(), credentials_cache.clone()).await
            })
            .iter(|| async { orchestrator(&client).await })
    });
}

mod orchestrator {
    use aws_credential_types::cache::SharedCredentialsCache;
    use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
    use aws_runtime::recursion_detection::RecursionDetectionInterceptor;
    use aws_runtime::user_agent::UserAgentInterceptor;
    use aws_sdk_s3::config::Region;
    use aws_sdk_s3::endpoint::Params;
    use aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Input;
    use aws_smithy_client::erase::DynConnector;
    use aws_smithy_http::endpoint::SharedEndpointResolver;
    use aws_smithy_runtime::client::connections::adapter::DynConnectorAdapter;
    use aws_smithy_runtime::client::orchestrator::endpoints::DefaultEndpointResolver;
    use aws_smithy_runtime_api::client::interceptors::error::ContextAttachedError;
    use aws_smithy_runtime_api::client::interceptors::{
        Interceptor, InterceptorContext, Interceptors,
    };
    use aws_smithy_runtime_api::client::orchestrator::{
        BoxError, ConfigBagAccessors, Connection, HttpRequest, HttpResponse, TraceProbe,
    };
    use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
    use aws_smithy_runtime_api::config_bag::ConfigBag;
    use aws_types::region::SigningRegion;
    use aws_types::SigningService;
    use std::sync::Arc;

    pub struct ManualServiceRuntimePlugin {
        pub connector: DynConnector,
        pub endpoint_resolver: SharedEndpointResolver<Params>,
        pub credentials_cache: SharedCredentialsCache,
    }

    impl RuntimePlugin for ManualServiceRuntimePlugin {
        fn configure(&self, cfg: &mut ConfigBag) -> Result<(), BoxError> {
            let identity_resolvers =
                aws_smithy_runtime_api::client::identity::IdentityResolvers::builder()
                    .identity_resolver(
                        aws_runtime::auth::sigv4::SCHEME_ID,
                        aws_runtime::identity::credentials::CredentialsIdentityResolver::new(
                            self.credentials_cache.clone(),
                        ),
                    )
                    .identity_resolver(
                        "anonymous",
                        aws_smithy_runtime_api::client::identity::AnonymousIdentityResolver::new(),
                    )
                    .build();
            cfg.set_identity_resolvers(identity_resolvers);

            let http_auth_schemes =
                aws_smithy_runtime_api::client::orchestrator::HttpAuthSchemes::builder()
                    .auth_scheme(
                        aws_runtime::auth::sigv4::SCHEME_ID,
                        aws_runtime::auth::sigv4::SigV4HttpAuthScheme::new(),
                    )
                    .build();
            cfg.set_http_auth_schemes(http_auth_schemes);

            cfg.set_auth_option_resolver(
                aws_smithy_runtime_api::client::auth::option_resolver::AuthOptionListResolver::new(
                    Vec::new(),
                ),
            );

            cfg.set_endpoint_resolver(DefaultEndpointResolver::new(self.endpoint_resolver.clone()));

            let params_builder = aws_sdk_s3::endpoint::Params::builder()
                .set_region(Some("us-east-1".to_owned()))
                .set_endpoint(Some("https://s3.us-east-1.amazonaws.com/".to_owned()));
            cfg.put(params_builder);

            cfg.set_retry_strategy(
                aws_smithy_runtime_api::client::retries::NeverRetryStrategy::new(),
            );

            let connection: Box<dyn Connection> =
                Box::new(DynConnectorAdapter::new(self.connector.clone()));
            cfg.set_connection(connection);

            cfg.set_trace_probe({
                #[derive(Debug)]
                struct StubTraceProbe;
                impl TraceProbe for StubTraceProbe {
                    fn dispatch_events(&self) {
                        // no-op
                    }
                }
                StubTraceProbe
            });

            cfg.put(SigningService::from_static("s3"));
            cfg.put(SigningRegion::from(Region::from_static("us-east-1")));

            cfg.put(ApiMetadata::new("unused", "unused"));
            cfg.put(AwsUserAgent::for_tests()); // Override the user agent with the test UA
            cfg.get::<Interceptors<HttpRequest, HttpResponse>>()
                .expect("interceptors set")
                .register_client_interceptor(Arc::new(UserAgentInterceptor::new()) as _)
                .register_client_interceptor(Arc::new(RecursionDetectionInterceptor::new()) as _);
            Ok(())
        }
    }

    // This is a temporary operation runtime plugin until <Operation>EndpointParamsInterceptor and
    // <Operation>EndpointParamsFinalizerInterceptor have been fully implemented, in which case
    // `.with_operation_plugin(ManualOperationRuntimePlugin)` can be removed.
    pub struct ManualOperationRuntimePlugin;

    impl RuntimePlugin for ManualOperationRuntimePlugin {
        fn configure(&self, cfg: &mut ConfigBag) -> Result<(), BoxError> {
            #[derive(Debug)]
            struct ListObjectsV2EndpointParamsInterceptor;
            impl Interceptor<HttpRequest, HttpResponse> for ListObjectsV2EndpointParamsInterceptor {
                fn read_before_execution(
                    &self,
                    context: &InterceptorContext<HttpRequest, HttpResponse>,
                    cfg: &mut ConfigBag,
                ) -> Result<(), BoxError> {
                    let input = context.input()?;
                    let input = input
                        .downcast_ref::<ListObjectsV2Input>()
                        .ok_or_else(|| "failed to downcast to ListObjectsV2Input")?;
                    let mut params_builder = cfg
                        .get::<aws_sdk_s3::endpoint::ParamsBuilder>()
                        .ok_or_else(|| "missing endpoint params builder")?
                        .clone();
                    params_builder = params_builder.set_bucket(input.bucket.clone());
                    cfg.put(params_builder);

                    Ok(())
                }
            }

            #[derive(Debug)]
            struct ListObjectsV2EndpointParamsFinalizerInterceptor;
            impl Interceptor<HttpRequest, HttpResponse> for ListObjectsV2EndpointParamsFinalizerInterceptor {
                fn read_before_execution(
                    &self,
                    _context: &InterceptorContext<HttpRequest, HttpResponse>,
                    cfg: &mut ConfigBag,
                ) -> Result<(), BoxError> {
                    let params_builder = cfg
                        .get::<aws_sdk_s3::endpoint::ParamsBuilder>()
                        .ok_or_else(|| "missing endpoint params builder")?
                        .clone();
                    let params = params_builder.build().map_err(|err| {
                        ContextAttachedError::new("endpoint params could not be built", err)
                    })?;
                    cfg.put(
                        aws_smithy_runtime_api::client::orchestrator::EndpointResolverParams::new(
                            params,
                        ),
                    );

                    Ok(())
                }
            }

            cfg.get::<Interceptors<HttpRequest, HttpResponse>>()
                .expect("interceptors set")
                .register_operation_interceptor(
                    Arc::new(ListObjectsV2EndpointParamsInterceptor) as _
                )
                .register_operation_interceptor(Arc::new(
                    ListObjectsV2EndpointParamsFinalizerInterceptor,
                ) as _);
            Ok(())
        }
    }
}

criterion_group!(benches, middleware_bench, orchestrator_bench);
criterion_main!(benches);
+0 −4
Original line number Diff line number Diff line
@@ -7,17 +7,13 @@ use aws_http::user_agent::AwsUserAgent;
use aws_runtime::invocation_id::InvocationId;
use aws_sdk_s3::config::{Credentials, Region};
use aws_sdk_s3::endpoint::Params;

use aws_sdk_s3::Client;

use aws_smithy_client::dvr;
use aws_smithy_client::dvr::MediaType;
use aws_smithy_client::erase::DynConnector;

use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, RequestTime};
use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
use aws_smithy_runtime_api::config_bag::ConfigBag;

use std::time::{Duration, SystemTime, UNIX_EPOCH};

const LIST_BUCKETS_PATH: &str = "test-data/list-objects-v2.json";
Loading