Unverified Commit beedd2c7 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Refactor interceptor phases to improve optionality inside interceptors (#2670)

## Motivation and Context
This PR:
- Deletes `TraceProbe`
- Replaces the orchestrator's `Phase` helper with a couple of macros
- Introduces the concept of phases into `InterceptorContext` so that
input/output/error/request/response accessors don't need option wrappers
- Adds `TypeErasedError` so that `orchestrator::Error` can implement
`Error`
- Rewinds the interceptor context in the retry loop

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent c75807ce
Loading
Loading
Loading
Loading
+13 −10
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::error::BoxError;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
@@ -37,10 +38,10 @@ impl Default for InvocationIdInterceptor {
impl Interceptor for InvocationIdInterceptor {
    fn modify_before_retry_loop(
        &self,
        context: &mut InterceptorContext,
        context: &mut InterceptorContext<BeforeTransmit>,
        _cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let headers = context.request_mut()?.headers_mut();
        let headers = context.request_mut().headers_mut();
        let id = _cfg.get::<InvocationId>().unwrap_or(&self.id);
        headers.append(AMZ_SDK_INVOCATION_ID, id.0.clone());
        Ok(())
@@ -72,24 +73,26 @@ impl InvocationId {
mod tests {
    use crate::invocation_id::InvocationIdInterceptor;
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
    use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
    use aws_smithy_runtime_api::config_bag::ConfigBag;
    use aws_smithy_runtime_api::type_erasure::TypedBox;
    use http::HeaderValue;

    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a HeaderValue {
        context
            .request()
            .unwrap()
            .headers()
            .get(header_name)
            .unwrap()
    fn expect_header<'a>(
        context: &'a InterceptorContext<BeforeTransmit>,
        header_name: &str,
    ) -> &'a HeaderValue {
        context.request().headers().get(header_name).unwrap()
    }

    #[test]
    fn test_id_is_generated_and_set() {
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
            .into_serialization_phase();
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
        let _ = context.take_input();
        let mut context = context.into_before_transmit_phase();

        let mut config = ConfigBag::base();
        let interceptor = InvocationIdInterceptor::new();
+8 −4
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_types::os_shim_internal::Env;
@@ -39,10 +40,10 @@ impl RecursionDetectionInterceptor {
impl Interceptor for RecursionDetectionInterceptor {
    fn modify_before_signing(
        &self,
        context: &mut InterceptorContext,
        context: &mut InterceptorContext<BeforeTransmit>,
        _cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let request = context.request_mut()?;
        let request = context.request_mut();
        if request.headers().contains_key(TRACE_ID_HEADER) {
            return Ok(());
        }
@@ -145,14 +146,17 @@ mod tests {
            request = request.header(name, value);
        }
        let request = request.body(SdkBody::empty()).expect("must be valid");
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
            .into_serialization_phase();
        context.set_request(request);
        let _ = context.take_input();
        let mut context = context.into_before_transmit_phase();
        let mut config = ConfigBag::base();

        RecursionDetectionInterceptor { env }
            .modify_before_signing(&mut context, &mut config)
            .expect("interceptor must succeed");
        let mutated_request = context.request().expect("request is still set");
        let mutated_request = context.request();
        for name in mutated_request.headers().keys() {
            assert_eq!(
                mutated_request.headers().get_all(name).iter().count(),
+19 −13
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@
 */

use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::error::BoxError;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
@@ -72,7 +73,7 @@ fn header_values(
impl Interceptor for UserAgentInterceptor {
    fn modify_before_signing(
        &self,
        context: &mut InterceptorContext,
        context: &mut InterceptorContext<BeforeTransmit>,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let api_metadata = cfg
@@ -95,7 +96,7 @@ impl Interceptor for UserAgentInterceptor {
                Cow::Owned(ua)
            });

        let headers = context.request_mut()?.headers_mut();
        let headers = context.request_mut().headers_mut();
        let (user_agent, x_amz_user_agent) = header_values(&ua)?;
        headers.append(USER_AGENT, user_agent);
        headers.append(X_AMZ_USER_AGENT, x_amz_user_agent);
@@ -112,10 +113,12 @@ mod tests {
    use aws_smithy_runtime_api::type_erasure::TypedBox;
    use aws_smithy_types::error::display::DisplayErrorContext;

    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
    fn expect_header<'a>(
        context: &'a InterceptorContext<BeforeTransmit>,
        header_name: &str,
    ) -> &'a str {
        context
            .request()
            .unwrap()
            .headers()
            .get(header_name)
            .unwrap()
@@ -123,10 +126,17 @@ mod tests {
            .unwrap()
    }

    fn context() -> InterceptorContext<BeforeTransmit> {
        let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
            .into_serialization_phase();
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
        let _ = context.take_input();
        context.into_before_transmit_phase()
    }

    #[test]
    fn test_overridden_ua() {
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
        let mut context = context();

        let mut config = ConfigBag::base();
        config.put(AwsUserAgent::for_tests());
@@ -149,8 +159,7 @@ mod tests {

    #[test]
    fn test_default_ua() {
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
        let mut context = context();

        let api_metadata = ApiMetadata::new("some-service", "some-version");
        let mut config = ConfigBag::base();
@@ -178,8 +187,7 @@ mod tests {

    #[test]
    fn test_app_name() {
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
        let mut context = context();

        let api_metadata = ApiMetadata::new("some-service", "some-version");
        let mut config = ConfigBag::base();
@@ -207,9 +215,7 @@ mod tests {

    #[test]
    fn test_api_metadata_missing() {
        let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
        context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());

        let mut context = context();
        let mut config = ConfigBag::base();

        let interceptor = UserAgentInterceptor::new();
+1 −1
Original line number Diff line number Diff line
@@ -74,7 +74,7 @@ class OperationRetryClassifiersFeature(
        "RetryClassifiers" to smithyRuntimeApi.resolve("client::retries::RetryClassifiers"),
        "OperationError" to codegenContext.symbolProvider.symbolForOperationError(operation),
        "SdkError" to RuntimeType.smithyHttp(runtimeConfig).resolve("result::SdkError"),
        "ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedBox"),
        "ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedError"),
    )

    override fun section(section: OperationRuntimePluginSection) = when (section) {
+3 −3
Original line number Diff line number Diff line
@@ -49,6 +49,7 @@ class EndpointParamsInterceptorGenerator(
            "HttpResponse" to orchestrator.resolve("HttpResponse"),
            "Interceptor" to interceptors.resolve("Interceptor"),
            "InterceptorContext" to interceptors.resolve("InterceptorContext"),
            "BeforeSerializationPhase" to interceptors.resolve("context::phase::BeforeSerialization"),
            "InterceptorError" to interceptors.resolve("error::InterceptorError"),
            "Params" to endpointTypesGenerator.paramsStruct(),
        )
@@ -66,11 +67,10 @@ class EndpointParamsInterceptorGenerator(
            impl #{Interceptor} for $interceptorName {
                fn read_before_execution(
                    &self,
                    context: &#{InterceptorContext},
                    context: &#{InterceptorContext}<#{BeforeSerializationPhase}>,
                    cfg: &mut #{ConfigBag},
                ) -> Result<(), #{BoxError}> {
                    let _input = context.input()?;
                    let _input = _input
                    let _input = context.input()
                        .downcast_ref::<${operationInput.name}>()
                        .ok_or("failed to downcast to ${operationInput.name}")?;

Loading