From a9e22ce17a8e6c4b2cfbcd0e5a3b5e6d74bb3c05 Mon Sep 17 00:00:00 2001 From: Zelda Hessler Date: Thu, 18 May 2023 09:59:36 -0500 Subject: [PATCH] Fix: orchestrator flow (#2699) ## Motivation and Context Necessary before we can implement retries. ## Description I noticed that we weren't handling the flow quite right when errors occurred. This PR fixes that and adds interceptor-based tests to ensure things are working right. I still think we could use more tests but the PR is already quite large. ## Testing This PR contains tests ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --- .../src/apigateway_interceptors.rs | 7 +- .../aws-runtime/src/invocation_id.rs | 25 +- .../aws-runtime/src/recursion_detection.rs | 17 +- .../aws-runtime/src/request_info.rs | 22 +- .../aws-runtime/src/user_agent.rs | 34 +- .../aws-sdk-s3/tests/interceptors.rs | 10 +- .../aws-sdk-s3/tests/util.rs | 6 +- .../EndpointParamsInterceptorGenerator.kt | 7 +- .../protocol/ResponseDeserializerGenerator.kt | 10 +- .../src/client/interceptors.rs | 491 ++++++----- .../src/client/interceptors/context.rs | 784 ++++++------------ .../src/client/interceptors/context/phase.rs | 53 ++ .../client/interceptors/context/wrappers.rs | 253 ++++++ .../src/client/interceptors/error.rs | 5 + .../src/client/orchestrator.rs | 15 +- .../src/client/orchestrator/error.rs | 117 +++ .../src/client/retries.rs | 9 +- .../src/type_erasure.rs | 34 +- .../src/client/interceptor.rs | 9 +- .../src/client/orchestrator.rs | 621 ++++++++++---- .../src/client/orchestrator/auth.rs | 109 +-- .../src/client/orchestrator/endpoints.rs | 3 +- .../interceptors/request_attempts.rs | 9 +- .../interceptors/service_clock_skew.rs | 9 +- .../client/retries/strategy/fixed_delay.rs | 3 +- .../src/client/retries/strategy/never.rs | 3 +- .../src/client/test_util/deserializer.rs | 13 +- 27 files changed, 1616 insertions(+), 1062 deletions(-) create mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/phase.rs create mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/wrappers.rs create mode 100644 rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs diff --git a/aws/rust-runtime/aws-inlineable/src/apigateway_interceptors.rs b/aws/rust-runtime/aws-inlineable/src/apigateway_interceptors.rs index 0ba3c1fa1..b6e56890d 100644 --- a/aws/rust-runtime/aws-inlineable/src/apigateway_interceptors.rs +++ b/aws/rust-runtime/aws-inlineable/src/apigateway_interceptors.rs @@ -5,8 +5,9 @@ #![allow(dead_code)] -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use http::header::ACCEPT; use http::HeaderValue; @@ -18,7 +19,7 @@ pub(crate) struct AcceptHeaderInterceptor; impl Interceptor for AcceptHeaderInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { context diff --git a/aws/rust-runtime/aws-runtime/src/invocation_id.rs b/aws/rust-runtime/aws-runtime/src/invocation_id.rs index 98d5405ec..b5234264e 100644 --- a/aws/rust-runtime/aws-runtime/src/invocation_id.rs +++ b/aws/rust-runtime/aws-runtime/src/invocation_id.rs @@ -3,9 +3,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::error::BoxError; -use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use http::{HeaderName, HeaderValue}; use uuid::Uuid; @@ -38,7 +39,7 @@ impl Default for InvocationIdInterceptor { impl Interceptor for InvocationIdInterceptor { fn modify_before_retry_loop( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let headers = context.request_mut().headers_mut(); @@ -73,31 +74,31 @@ impl InvocationId { mod tests { use crate::invocation_id::InvocationIdInterceptor; use aws_smithy_http::body::SdkBody; - use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_runtime_api::type_erasure::TypedBox; use http::HeaderValue; - fn expect_header<'a>( - context: &'a InterceptorContext, - header_name: &str, - ) -> &'a HeaderValue { + fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a HeaderValue { context.request().headers().get(header_name).unwrap() } #[test] fn test_id_is_generated_and_set() { - let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) - .into_serialization_phase(); + let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + context.enter_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); let _ = context.take_input(); - let mut context = context.into_before_transmit_phase(); + context.enter_before_transmit_phase(); let mut config = ConfigBag::base(); let interceptor = InvocationIdInterceptor::new(); + let mut ctx = Into::into(&mut context); interceptor - .modify_before_retry_loop(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) + .unwrap(); + interceptor + .modify_before_retry_loop(&mut ctx, &mut config) .unwrap(); let header = expect_header(&context, "amz-sdk-invocation-id"); diff --git a/aws/rust-runtime/aws-runtime/src/recursion_detection.rs b/aws/rust-runtime/aws-runtime/src/recursion_detection.rs index 38d2cf272..4c0d9f10b 100644 --- a/aws/rust-runtime/aws-runtime/src/recursion_detection.rs +++ b/aws/rust-runtime/aws-runtime/src/recursion_detection.rs @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_types::os_shim_internal::Env; use http::HeaderValue; @@ -40,7 +41,7 @@ impl RecursionDetectionInterceptor { impl Interceptor for RecursionDetectionInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let request = context.request_mut(); @@ -73,6 +74,7 @@ mod tests { use super::*; use aws_smithy_http::body::SdkBody; use aws_smithy_protocol_test::{assert_ok, validate_headers}; + use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::type_erasure::TypedBox; use aws_types::os_shim_internal::Env; use http::HeaderValue; @@ -146,15 +148,16 @@ mod tests { request = request.header(name, value); } let request = request.body(SdkBody::empty()).expect("must be valid"); - let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) - .into_serialization_phase(); + let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + context.enter_serialization_phase(); context.set_request(request); let _ = context.take_input(); - let mut context = context.into_before_transmit_phase(); + context.enter_before_transmit_phase(); let mut config = ConfigBag::base(); + let mut ctx = Into::into(&mut context); RecursionDetectionInterceptor { env } - .modify_before_signing(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) .expect("interceptor must succeed"); let mutated_request = context.request(); for name in mutated_request.headers().keys() { diff --git a/aws/rust-runtime/aws-runtime/src/request_info.rs b/aws/rust-runtime/aws-runtime/src/request_info.rs index 05a80e220..3896654ef 100644 --- a/aws/rust-runtime/aws-runtime/src/request_info.rs +++ b/aws/rust-runtime/aws-runtime/src/request_info.rs @@ -4,8 +4,9 @@ */ use aws_smithy_runtime::client::orchestrator::interceptors::{RequestAttempts, ServiceClockSkew}; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_types::date_time::Format; use aws_smithy_types::retry::RetryConfig; @@ -79,7 +80,7 @@ impl RequestInfoInterceptor { impl Interceptor for RequestInfoInterceptor { fn modify_before_transmit( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let mut pairs = RequestPairs::new(); @@ -156,7 +157,6 @@ mod tests { use crate::request_info::RequestPairs; use aws_smithy_http::body::SdkBody; use aws_smithy_runtime::client::orchestrator::interceptors::RequestAttempts; - use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_runtime_api::type_erasure::TypedBox; @@ -165,10 +165,7 @@ mod tests { use http::HeaderValue; use std::time::Duration; - fn expect_header<'a>( - context: &'a InterceptorContext, - header_name: &str, - ) -> &'a str { + fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str { context .request() .headers() @@ -180,8 +177,8 @@ mod tests { #[test] fn test_request_pairs_for_initial_attempt() { - let context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()); - let mut context = context.into_serialization_phase(); + let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + context.enter_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); let mut config = ConfigBag::base(); @@ -194,10 +191,11 @@ mod tests { config.put(RequestAttempts::new()); let _ = context.take_input(); - let mut context = context.into_before_transmit_phase(); + context.enter_before_transmit_phase(); let interceptor = RequestInfoInterceptor::new(); + let mut ctx = (&mut context).into(); interceptor - .modify_before_transmit(&mut context, &mut config) + .modify_before_transmit(&mut ctx, &mut config) .unwrap(); assert_eq!( diff --git a/aws/rust-runtime/aws-runtime/src/user_agent.rs b/aws/rust-runtime/aws-runtime/src/user_agent.rs index cdd3003da..326cc5cca 100644 --- a/aws/rust-runtime/aws-runtime/src/user_agent.rs +++ b/aws/rust-runtime/aws-runtime/src/user_agent.rs @@ -4,9 +4,10 @@ */ use aws_http::user_agent::{ApiMetadata, AwsUserAgent}; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::error::BoxError; -use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_types::app_name::AppName; use aws_types::os_shim_internal::Env; @@ -73,7 +74,7 @@ fn header_values( impl Interceptor for UserAgentInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let api_metadata = cfg @@ -113,10 +114,7 @@ mod tests { use aws_smithy_runtime_api::type_erasure::TypedBox; use aws_smithy_types::error::display::DisplayErrorContext; - fn expect_header<'a>( - context: &'a InterceptorContext, - header_name: &str, - ) -> &'a str { + fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str { context .request() .headers() @@ -126,12 +124,13 @@ mod tests { .unwrap() } - fn context() -> InterceptorContext { - let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) - .into_serialization_phase(); + fn context() -> InterceptorContext { + let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + context.enter_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); let _ = context.take_input(); - context.into_before_transmit_phase() + context.enter_before_transmit_phase(); + context } #[test] @@ -143,8 +142,9 @@ mod tests { config.put(ApiMetadata::new("unused", "unused")); let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); interceptor - .modify_before_signing(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) .unwrap(); let header = expect_header(&context, "user-agent"); @@ -166,8 +166,9 @@ mod tests { config.put(api_metadata.clone()); let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); interceptor - .modify_before_signing(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) .unwrap(); let expected_ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata); @@ -195,8 +196,9 @@ mod tests { config.put(AppName::new("my_awesome_app").unwrap()); let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); interceptor - .modify_before_signing(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) .unwrap(); let app_value = "app/my_awesome_app"; @@ -219,11 +221,13 @@ mod tests { let mut config = ConfigBag::base(); let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); + let error = format!( "{}", DisplayErrorContext( &*interceptor - .modify_before_signing(&mut context, &mut config) + .modify_before_signing(&mut ctx, &mut config) .expect_err("it should error") ) ); diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs b/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs index f33ef94ff..1f061f7f9 100644 --- a/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs +++ b/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs @@ -10,8 +10,10 @@ use aws_sdk_s3::Client; use aws_smithy_client::dvr; use aws_smithy_client::dvr::MediaType; use aws_smithy_client::erase::DynConnector; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, Interceptor, +}; use aws_smithy_runtime_api::client::orchestrator::ConfigBagAccessors; use aws_smithy_runtime_api::client::orchestrator::RequestTime; use aws_smithy_runtime_api::config_bag::ConfigBag; @@ -59,7 +61,7 @@ struct RequestTimeResetInterceptor; impl Interceptor for RequestTimeResetInterceptor { fn modify_before_signing( &self, - _context: &mut InterceptorContext, + _context: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), aws_smithy_runtime_api::client::interceptors::BoxError> { cfg.set_request_time(RequestTime::new(UNIX_EPOCH)); @@ -73,7 +75,7 @@ struct RequestTimeAdvanceInterceptor(Duration); impl Interceptor for RequestTimeAdvanceInterceptor { fn modify_before_signing( &self, - _context: &mut InterceptorContext, + _context: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), aws_smithy_runtime_api::client::interceptors::BoxError> { let request_time = cfg.request_time().unwrap(); diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/tests/util.rs b/aws/sra-test/integration-tests/aws-sdk-s3/tests/util.rs index af3d53094..4f2997f5c 100644 --- a/aws/sra-test/integration-tests/aws-sdk-s3/tests/util.rs +++ b/aws/sra-test/integration-tests/aws-sdk-s3/tests/util.rs @@ -5,9 +5,9 @@ use aws_http::user_agent::AwsUserAgent; use aws_runtime::invocation_id::InvocationId; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; +use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output}; use aws_smithy_runtime_api::client::interceptors::{ - Interceptor, InterceptorContext, InterceptorRegistrar, + BeforeTransmitInterceptorContextMut, Interceptor, InterceptorRegistrar, }; use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, RequestTime}; use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; @@ -39,7 +39,7 @@ pub struct TestUserAgentInterceptor; impl Interceptor for TestUserAgentInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), aws_smithy_runtime_api::client::interceptors::BoxError> { let headers = context.request_mut().headers_mut(); diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt index ee1975d65..04eb32ca5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt @@ -49,7 +49,10 @@ class EndpointParamsInterceptorGenerator( "HttpResponse" to orchestrator.resolve("HttpResponse"), "Interceptor" to interceptors.resolve("Interceptor"), "InterceptorContext" to interceptors.resolve("InterceptorContext"), - "BeforeSerializationPhase" to interceptors.resolve("context::phase::BeforeSerialization"), + "BeforeSerializationInterceptorContextRef" to interceptors.resolve("context::wrappers::BeforeSerializationInterceptorContextRef"), + "Input" to interceptors.resolve("context::Input"), + "Output" to interceptors.resolve("context::Output"), + "Error" to interceptors.resolve("context::Error"), "InterceptorError" to interceptors.resolve("error::InterceptorError"), "Params" to endpointTypesGenerator.paramsStruct(), ) @@ -67,7 +70,7 @@ class EndpointParamsInterceptorGenerator( impl #{Interceptor} for $interceptorName { fn read_before_execution( &self, - context: &#{InterceptorContext}<#{BeforeSerializationPhase}>, + context: &#{BeforeSerializationInterceptorContextRef}<'_, #{Input}, #{Output}, #{Error}>, cfg: &mut #{ConfigBag}, ) -> Result<(), #{BoxError}> { let _input = context.input() diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt index 1df6e33c6..83bc4dff5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt @@ -41,6 +41,7 @@ class ResponseDeserializerGenerator( "Instrument" to CargoDependency.Tracing.toType().resolve("Instrument"), "Output" to interceptorContext.resolve("Output"), "OutputOrError" to interceptorContext.resolve("OutputOrError"), + "OrchestratorError" to orchestrator.resolve("OrchestratorError"), "ResponseDeserializer" to orchestrator.resolve("ResponseDeserializer"), "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "SdkError" to RuntimeType.sdkError(runtimeConfig), @@ -98,7 +99,7 @@ class ResponseDeserializerGenerator( if !response.status().is_success() && response.status().as_u16() != $successCode { return None; } - Some(#{type_erase_result}(#{parse_streaming_response}(response))) + Some(#{type_erase_result}(#{parse_streaming_response}(response)).into()) } """, *codegenScope, @@ -117,7 +118,7 @@ class ResponseDeserializerGenerator( """ // For streaming operations, we only hit this case if its an error let body = response.body().bytes().expect("body loaded"); - #{type_erase_result}(#{parse_error}(response.status().as_u16(), response.headers(), body)) + #{type_erase_result}(#{parse_error}(response.status().as_u16(), response.headers(), body)).into() """, *codegenScope, "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), @@ -140,7 +141,7 @@ class ResponseDeserializerGenerator( } else { #{parse_response}(status, headers, body) }; - #{type_erase_result}(parse_result) + #{type_erase_result}(parse_result).into() """, *codegenScope, "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), @@ -154,13 +155,14 @@ class ResponseDeserializerGenerator( private fun typeEraseResult(): RuntimeType = ProtocolFunctions.crossOperationFn("type_erase_result") { fnName -> rustTemplate( """ - pub(crate) fn $fnName(result: Result) -> Result<#{Output}, #{Error}> + pub(crate) fn $fnName(result: Result) -> Result<#{Output}, #{OrchestratorError}<#{Error}>> where O: std::fmt::Debug + Send + Sync + 'static, E: std::error::Error + std::fmt::Debug + Send + Sync + 'static, { result.map(|output| #{TypedBox}::new(output).erase()) .map_err(|error| #{TypedBox}::new(error).erase_error()) + .map_err(Into::into) } """, *codegenScope, diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs index ce4195d68..ec322badf 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs @@ -6,12 +6,21 @@ pub mod context; pub mod error; -use crate::client::interceptors::context::phase::{ - AfterDeserialization, BeforeDeserialization, BeforeSerialization, BeforeTransmit, +use crate::client::interceptors::context::wrappers::{ + FinalizerInterceptorContextMut, FinalizerInterceptorContextRef, }; use crate::config_bag::ConfigBag; use aws_smithy_types::error::display::DisplayErrorContext; -pub use context::InterceptorContext; +pub use context::{ + wrappers::{ + AfterDeserializationInterceptorContextMut, AfterDeserializationInterceptorContextRef, + BeforeDeserializationInterceptorContextMut, BeforeDeserializationInterceptorContextRef, + BeforeSerializationInterceptorContextMut, BeforeSerializationInterceptorContextRef, + BeforeTransmitInterceptorContextMut, BeforeTransmitInterceptorContextRef, + }, + InterceptorContext, +}; +use context::{Error, Input, Output}; pub use error::{BoxError, InterceptorError}; use std::ops::Deref; use std::sync::Arc; @@ -19,11 +28,7 @@ use std::sync::Arc; macro_rules! interceptor_trait_fn { ($name:ident, $phase:ident, $docs:tt) => { #[doc = $docs] - fn $name( - &self, - context: &InterceptorContext<$phase>, - cfg: &mut ConfigBag, - ) -> Result<(), BoxError> { + fn $name(&self, context: &$phase<'_>, cfg: &mut ConfigBag) -> Result<(), BoxError> { let _ctx = context; let _cfg = cfg; Ok(()) @@ -31,11 +36,7 @@ macro_rules! interceptor_trait_fn { }; (mut $name:ident, $phase:ident, $docs:tt) => { #[doc = $docs] - fn $name( - &self, - context: &mut InterceptorContext<$phase>, - cfg: &mut ConfigBag, - ) -> Result<(), BoxError> { + fn $name(&self, context: &mut $phase<'_>, cfg: &mut ConfigBag) -> Result<(), BoxError> { let _ctx = context; let _cfg = cfg; Ok(()) @@ -56,7 +57,7 @@ macro_rules! interceptor_trait_fn { pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_execution, - BeforeSerialization, + BeforeSerializationInterceptorContextRef, " A hook called at the start of an execution, before the SDK does anything else. @@ -80,7 +81,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_serialization, - BeforeSerialization, + BeforeSerializationInterceptorContextMut, " A hook called before the input message is marshalled into a transport message. @@ -108,7 +109,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_serialization, - BeforeSerialization, + BeforeSerializationInterceptorContextRef, " A hook called before the input message is marshalled into a transport @@ -130,7 +131,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_serialization, - BeforeTransmit, + BeforeTransmitInterceptorContextRef, " /// A hook called after the input message is marshalled into /// a transport message. @@ -152,7 +153,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_retry_loop, - BeforeTransmit, + BeforeTransmitInterceptorContextMut, " A hook called before the retry loop is entered. This method has the ability to modify and return a new transport request @@ -174,7 +175,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_attempt, - BeforeTransmit, + BeforeTransmitInterceptorContextRef, " A hook called before each attempt at sending the transmission request message to the service. @@ -201,7 +202,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_signing, - BeforeTransmit, + BeforeTransmitInterceptorContextMut, " A hook called before the transport request message is signed. This method has the ability to modify and return a new transport @@ -233,7 +234,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_signing, - BeforeTransmit, + BeforeTransmitInterceptorContextRef, " A hook called before the transport request message is signed. @@ -257,7 +258,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_signing, - BeforeTransmit, + BeforeTransmitInterceptorContextRef, " A hook called after the transport request message is signed. @@ -281,7 +282,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_transmit, - BeforeTransmit, + BeforeTransmitInterceptorContextMut, " /// A hook called before the transport request message is sent to the /// service. This method has the ability to modify and return @@ -313,7 +314,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_transmit, - BeforeTransmit, + BeforeTransmitInterceptorContextRef, " A hook called before the transport request message is sent to the service. @@ -341,7 +342,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_transmit, - BeforeDeserialization, + BeforeDeserializationInterceptorContextRef, " A hook called after the transport request message is sent to the service and a transport response message is received. @@ -369,7 +370,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_deserialization, - BeforeDeserialization, + BeforeDeserializationInterceptorContextMut, " A hook called before the transport response message is unmarshalled. This method has the ability to modify and return a new transport @@ -401,7 +402,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_deserialization, - BeforeDeserialization, + BeforeDeserializationInterceptorContextRef, " A hook called before the transport response message is unmarshalled @@ -428,7 +429,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_deserialization, - AfterDeserialization, + AfterDeserializationInterceptorContextRef, " A hook called after the transport response message is unmarshalled. @@ -453,115 +454,127 @@ pub trait Interceptor: std::fmt::Debug { " ); - interceptor_trait_fn!( - mut modify_before_attempt_completion, - AfterDeserialization, - " - A hook called when an attempt is completed. This method has the - ability to modify and return a new output message or error - matching the currently-executing operation. - - **When:** This will **ALWAYS** be called once per attempt, except when a - failure occurs before `before_attempt`. This method may - be called multiple times in the event of retries. - - **Available Information:** The [InterceptorContext::input()], - [InterceptorContext::request()], - [InterceptorContext::response()] and - [InterceptorContext::output_or_error()] are **ALWAYS** available. In the event - of retries, the `InterceptorContext` will not include changes made - in previous attempts (e.g. by request signers or other interceptors). - - **Error Behavior:** If errors are raised by this - hook, execution will jump to `after_attempt` with - the raised error as the [InterceptorContext::output_or_error()]. - - **Return Constraints:** Any output message returned by this - hook MUST match the operation being invoked. Any error type can be - returned, replacing the response currently in the context. - " - ); - - interceptor_trait_fn!( - read_after_attempt, - AfterDeserialization, - " - A hook called when an attempt is completed. - - **When:** This will **ALWAYS** be called once per attempt, as long as - `before_attempt` has been executed. - - **Available Information:** The [InterceptorContext::input()], - [InterceptorContext::request()] and - [InterceptorContext::output_or_error()] are **ALWAYS** available. - The [InterceptorContext::response()] is available if a - response was received by the service for this attempt. - In the event of retries, the `InterceptorContext` will not include - changes made in previous attempts (e.g. by request signers or other - interceptors). - - **Error Behavior:** Errors raised by this hook will be stored - until all interceptors have had their `after_attempt` invoked. - If multiple `after_execution` methods raise errors, the latest - will be used and earlier ones will be logged and dropped. If the - retry strategy determines that the execution is retryable, - execution will then jump to `before_attempt`. Otherwise, - execution will jump to `modify_before_attempt_completion` with the - raised error as the [InterceptorContext::output_or_error()]. - " - ); - - interceptor_trait_fn!( - mut modify_before_completion, - AfterDeserialization, - " - A hook called when an execution is completed. - This method has the ability to modify and return a new - output message or error matching the currently - executing - operation. - - **When:** This will **ALWAYS** be called once per execution. - - **Available Information:** The [InterceptorContext::input()] - and [InterceptorContext::output_or_error()] are **ALWAYS** available. The - [InterceptorContext::request()] - and [InterceptorContext::response()] are available if the - execution proceeded far enough for them to be generated. - - **Error Behavior:** If errors are raised by this - hook , execution will jump to `after_attempt` with - the raised error as the [InterceptorContext::output_or_error()]. - - **Return Constraints:** Any output message returned by this - hook MUST match the operation being invoked. Any error type can be - returned , replacing the response currently in the context. - " - ); - - interceptor_trait_fn!( - read_after_execution, - AfterDeserialization, - " - A hook called when an execution is completed. + /// A hook called when an attempt is completed. This method has the + /// ability to modify and return a new output message or error + /// matching the currently-executing operation. + /// + /// **When:** This will **ALWAYS** be called once per attempt, except when a + /// failure occurs before `before_attempt`. This method may + /// be called multiple times in the event of retries. + /// + /// **Available Information:** The [InterceptorContext::input()], + /// [InterceptorContext::request()], + /// [InterceptorContext::response()] and + /// [InterceptorContext::output_or_error()] are **ALWAYS** available. In the event + /// of retries, the `InterceptorContext` will not include changes made + /// in previous attempts (e.g. by request signers or other interceptors). + /// + /// **Error Behavior:** If errors are raised by this + /// hook, execution will jump to `after_attempt` with + /// the raised error as the [InterceptorContext::output_or_error()]. + /// + /// **Return Constraints:** Any output message returned by this + /// hook MUST match the operation being invoked. Any error type can be + /// returned, replacing the response currently in the context. + fn modify_before_attempt_completion( + &self, + context: &mut FinalizerInterceptorContextMut<'_>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let _ctx = context; + let _cfg = cfg; + Ok(()) + } - **When:** This will **ALWAYS** be called once per execution. The duration - between invocation of this hook and `before_execution` is very - close to the full duration of the execution. + /// A hook called when an attempt is completed. + /// + /// **When:** This will **ALWAYS** be called once per attempt, as long as + /// `before_attempt` has been executed. + /// + /// **Available Information:** The [InterceptorContext::input()], + /// [InterceptorContext::request()] and + /// [InterceptorContext::output_or_error()] are **ALWAYS** available. + /// The [InterceptorContext::response()] is available if a + /// response was received by the service for this attempt. + /// In the event of retries, the `InterceptorContext` will not include + /// changes made in previous attempts (e.g. by request signers or other + /// interceptors). + /// + /// **Error Behavior:** Errors raised by this hook will be stored + /// until all interceptors have had their `after_attempt` invoked. + /// If multiple `after_execution` methods raise errors, the latest + /// will be used and earlier ones will be logged and dropped. If the + /// retry strategy determines that the execution is retryable, + /// execution will then jump to `before_attempt`. Otherwise, + /// execution will jump to `modify_before_attempt_completion` with the + /// raised error as the [InterceptorContext::output_or_error()]. + fn read_after_attempt( + &self, + context: &FinalizerInterceptorContextRef<'_>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let _ctx = context; + let _cfg = cfg; + Ok(()) + } - **Available Information:** The [InterceptorContext::input()] - and [InterceptorContext::output_or_error()] are **ALWAYS** available. The - [InterceptorContext::request()] and - [InterceptorContext::response()] are available if the - execution proceeded far enough for them to be generated. + /// A hook called when an execution is completed. + /// This method has the ability to modify and return a new + /// output message or error matching the currently - executing + /// operation. + /// + /// **When:** This will **ALWAYS** be called once per execution. + /// + /// **Available Information:** The [InterceptorContext::input()] + /// and [InterceptorContext::output_or_error()] are **ALWAYS** available. The + /// [InterceptorContext::request()] + /// and [InterceptorContext::response()] are available if the + /// execution proceeded far enough for them to be generated. + /// + /// **Error Behavior:** If errors are raised by this + /// hook , execution will jump to `after_attempt` with + /// the raised error as the [InterceptorContext::output_or_error()]. + /// + /// **Return Constraints:** Any output message returned by this + /// hook MUST match the operation being invoked. Any error type can be + /// returned , replacing the response currently in the context. + fn modify_before_completion( + &self, + context: &mut FinalizerInterceptorContextMut<'_>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let _ctx = context; + let _cfg = cfg; + Ok(()) + } - **Error Behavior:** Errors raised by this hook will be stored - until all interceptors have had their `after_execution` invoked. - The error will then be treated as the - [InterceptorContext::output_or_error()] to the customer. If multiple - `after_execution` methods raise errors , the latest will be - used and earlier ones will be logged and dropped. - " - ); + /// A hook called when an execution is completed. + /// + /// **When:** This will **ALWAYS** be called once per execution. The duration + /// between invocation of this hook and `before_execution` is very + /// close to the full duration of the execution. + /// + /// **Available Information:** The [InterceptorContext::input()] + /// and [InterceptorContext::output_or_error()] are **ALWAYS** available. The + /// [InterceptorContext::request()] and + /// [InterceptorContext::response()] are available if the + /// execution proceeded far enough for them to be generated. + /// + /// **Error Behavior:** Errors raised by this hook will be stored + /// until all interceptors have had their `after_execution` invoked. + /// The error will then be treated as the + /// [InterceptorContext::output_or_error()] to the customer. If multiple + /// `after_execution` methods raise errors , the latest will be + /// used and earlier ones will be logged and dropped. + fn read_after_execution( + &self, + context: &FinalizerInterceptorContextRef<'_>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let _ctx = context; + let _cfg = cfg; + Ok(()) + } } /// Interceptor wrapper that may be shared @@ -619,40 +632,42 @@ pub struct Interceptors { } macro_rules! interceptor_impl_fn { - (context, $name:ident, $phase:ident) => { - interceptor_impl_fn!(context, $name, $name, $phase); - }; - (mut context, $name:ident, $phase:ident) => { - interceptor_impl_fn!(mut context, $name, $name, $phase); - }; - (context, $outer_name:ident, $inner_name:ident, $phase:ident) => { - interceptor_impl_fn!( - $outer_name, - $inner_name(context: &InterceptorContext<$phase>) - ); - }; - (mut context, $outer_name:ident, $inner_name:ident, $phase:ident) => { - interceptor_impl_fn!( - $outer_name, - $inner_name(context: &mut InterceptorContext<$phase>) - ); + (mut $interceptor:ident) => { + pub fn $interceptor( + &self, + ctx: &mut InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let mut ctx = ctx.into(); + for interceptor in self.interceptors() { + if let Err(new_error) = interceptor.$interceptor(&mut ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::$interceptor) + } }; - ($outer_name:ident, $inner_name:ident ($context:ident : $context_ty:ty)) => { - pub fn $outer_name( + (ref $interceptor:ident) => { + pub fn $interceptor( &self, - $context: $context_ty, + ctx: &InterceptorContext, cfg: &mut ConfigBag, ) -> Result<(), InterceptorError> { let mut result: Result<(), BoxError> = Ok(()); + let ctx = ctx.into(); for interceptor in self.interceptors() { - if let Err(new_error) = interceptor.$inner_name($context, cfg) { + if let Err(new_error) = interceptor.$interceptor(&ctx, cfg) { if let Err(last_error) = result { tracing::debug!("{}", DisplayErrorContext(&*last_error)); } result = Err(new_error); } } - result.map_err(InterceptorError::$inner_name) + result.map_err(InterceptorError::$interceptor) } }; } @@ -679,48 +694,128 @@ impl Interceptors { &mut self.operation_interceptors } - interceptor_impl_fn!( - context, - client_read_before_execution, - read_before_execution, - BeforeSerialization - ); - interceptor_impl_fn!( - context, - operation_read_before_execution, - read_before_execution, - BeforeSerialization - ); - interceptor_impl_fn!( - mut context, - modify_before_serialization, - BeforeSerialization - ); - interceptor_impl_fn!(context, read_before_serialization, BeforeSerialization); - interceptor_impl_fn!(context, read_after_serialization, BeforeTransmit); - interceptor_impl_fn!(mut context, modify_before_retry_loop, BeforeTransmit); - interceptor_impl_fn!(context, read_before_attempt, BeforeTransmit); - interceptor_impl_fn!(mut context, modify_before_signing, BeforeTransmit); - interceptor_impl_fn!(context, read_before_signing, BeforeTransmit); - interceptor_impl_fn!(context, read_after_signing, BeforeTransmit); - interceptor_impl_fn!(mut context, modify_before_transmit, BeforeTransmit); - interceptor_impl_fn!(context, read_before_transmit, BeforeTransmit); - interceptor_impl_fn!(context, read_after_transmit, BeforeDeserialization); - interceptor_impl_fn!( - mut context, - modify_before_deserialization, - BeforeDeserialization - ); - interceptor_impl_fn!(context, read_before_deserialization, BeforeDeserialization); - interceptor_impl_fn!(context, read_after_deserialization, AfterDeserialization); - interceptor_impl_fn!( - mut context, - modify_before_attempt_completion, - AfterDeserialization - ); - interceptor_impl_fn!(context, read_after_attempt, AfterDeserialization); - interceptor_impl_fn!(mut context, modify_before_completion, AfterDeserialization); - interceptor_impl_fn!(context, read_after_execution, AfterDeserialization); + pub fn client_read_before_execution( + &self, + ctx: &InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into(); + for interceptor in self.client_interceptors.0.iter() { + if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::read_before_execution) + } + + pub fn operation_read_before_execution( + &self, + ctx: &InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into(); + for interceptor in self.operation_interceptors.0.iter() { + if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::read_before_execution) + } + + interceptor_impl_fn!(mut modify_before_serialization); + interceptor_impl_fn!(ref read_before_serialization); + interceptor_impl_fn!(ref read_after_serialization); + interceptor_impl_fn!(mut modify_before_retry_loop); + interceptor_impl_fn!(ref read_before_attempt); + interceptor_impl_fn!(mut modify_before_signing); + interceptor_impl_fn!(ref read_before_signing); + interceptor_impl_fn!(ref read_after_signing); + interceptor_impl_fn!(mut modify_before_transmit); + interceptor_impl_fn!(ref read_before_transmit); + interceptor_impl_fn!(ref read_after_transmit); + interceptor_impl_fn!(mut modify_before_deserialization); + interceptor_impl_fn!(ref read_before_deserialization); + interceptor_impl_fn!(ref read_after_deserialization); + + pub fn modify_before_attempt_completion( + &self, + ctx: &mut InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into(); + for interceptor in self.interceptors() { + if let Err(new_error) = interceptor.modify_before_attempt_completion(&mut ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::modify_before_attempt_completion) + } + + pub fn read_after_attempt( + &self, + ctx: &InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let ctx: FinalizerInterceptorContextRef<'_> = ctx.into(); + for interceptor in self.interceptors() { + if let Err(new_error) = interceptor.read_after_attempt(&ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::read_after_attempt) + } + + pub fn modify_before_completion( + &self, + ctx: &mut InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into(); + for interceptor in self.interceptors() { + if let Err(new_error) = interceptor.modify_before_completion(&mut ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::modify_before_completion) + } + + pub fn read_after_execution( + &self, + ctx: &InterceptorContext, + cfg: &mut ConfigBag, + ) -> Result<(), InterceptorError> { + let mut result: Result<(), BoxError> = Ok(()); + let ctx: FinalizerInterceptorContextRef<'_> = ctx.into(); + for interceptor in self.interceptors() { + if let Err(new_error) = interceptor.read_after_execution(&ctx, cfg) { + if let Err(last_error) = result { + tracing::debug!("{}", DisplayErrorContext(&*last_error)); + } + result = Err(new_error); + } + } + result.map_err(InterceptorError::read_after_execution) + } } #[cfg(test)] diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs index 309b1ca9d..82481f437 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs @@ -7,184 +7,85 @@ //! //! Interceptors have access to varying pieces of context during the course of an operation. //! -//! An operation is composed of multiple phases. The initial phase is [`phase::BeforeSerialization`], which -//! has the original input as context. The next phase is [`phase::BeforeTransmit`], which has the serialized +//! An operation is composed of multiple phases. The initial phase is [`Phase::BeforeSerialization`], which +//! has the original input as context. The next phase is [`Phase::BeforeTransmit`], which has the serialized //! request as context. Depending on which hook is being called with the dispatch context, //! the serialized request may or may not be signed (which should be apparent from the hook name). -//! Following the [`phase::BeforeTransmit`] phase is the [`phase::BeforeDeserialization`] phase, which has -//! the raw response available as context. Finally, the [`phase::AfterDeserialization`] phase +//! Following the [`Phase::BeforeTransmit`] phase is the [`Phase::BeforeDeserialization`] phase, which has +//! the raw response available as context. Finally, the [`Phase::AfterDeserialization`] phase //! has both the raw and parsed response available. //! //! To summarize: -//! 1. [`phase::BeforeSerialization`]: Only has the operation input. -//! 2. [`phase::BeforeTransmit`]: Only has the serialized request. -//! 3. [`phase::BeforeDeserialization`]: Has the raw response. -//! 3. [`phase::AfterDeserialization`]: Has the raw response and the parsed response. +//! 1. [`Phase::BeforeSerialization`]: Only has the operation input. +//! 2. [`Phase::BeforeTransmit`]: Only has the serialized request. +//! 3. [`Phase::BeforeDeserialization`]: Has the raw response. +//! 3. [`Phase::AfterDeserialization`]: Has the raw response and the parsed response. //! //! When implementing hooks, if information from a previous phase is required, then implement //! an earlier hook to examine that context, and save off any necessary information into the -//! [`crate::config_bag::ConfigBag`] for later hooks to examine. Interior mutability is **NOT** +//! [`ConfigBag`] for later hooks to examine. Interior mutability is **NOT** //! recommended for storing request-specific information in your interceptor implementation. -//! Use the [`crate::config_bag::ConfigBag`] instead. +//! Use the [`ConfigBag`] instead. -use crate::client::interceptors::BoxError; -use crate::client::orchestrator::{HttpRequest, HttpResponse}; +/// Operation phases. +pub mod phase; +pub mod wrappers; + +use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError}; use crate::config_bag::ConfigBag; use crate::type_erasure::{TypeErasedBox, TypeErasedError}; use aws_smithy_http::result::SdkError; +use phase::Phase; +use std::fmt::Debug; +use std::mem; +use tracing::{error, trace}; pub type Input = TypeErasedBox; pub type Output = TypeErasedBox; pub type Error = TypeErasedError; -pub type OutputOrError = Result; +pub type OutputOrError = Result>; type Request = HttpRequest; type Response = HttpResponse; -/// Operation phases. -pub mod phase { - use crate::client::interceptors::context::{Error, Output}; - use crate::client::interceptors::BoxError; - use crate::client::orchestrator::HttpResponse; - use aws_smithy_http::result::{ConnectorError, SdkError}; - - macro_rules! impl_phase { - ($phase:ty, $convert_err:ident) => { - impl Phase for $phase { - fn convert_error( - &self, - error: BoxError, - output_or_error: Option>, - response: Option, - ) -> SdkError { - $convert_err(error, output_or_error, response) - } - } - }; - } - - #[doc(hidden)] - pub trait Phase { - fn convert_error( - &self, - error: BoxError, - output_or_error: Option>, - response: Option, - ) -> SdkError; - } - - fn convert_construction_failure( - error: BoxError, - _: Option>, - _: Option, - ) -> SdkError { - SdkError::construction_failure(error) - } - - fn convert_dispatch_error( - error: BoxError, - _: Option>, - response: Option, - ) -> SdkError { - let error = match error.downcast::() { - Ok(connector_error) => { - return SdkError::dispatch_failure(*connector_error); - } - Err(e) => e, - }; - if let Some(response) = response { - SdkError::response_error(error, response) - } else { - SdkError::dispatch_failure(ConnectorError::other(error, None)) - } - } - - fn convert_response_handling_error( - error: BoxError, - output_or_error: Option>, - response: Option, - ) -> SdkError { - match (response, output_or_error) { - (Some(response), Some(Err(error))) => SdkError::service_error(error, response), - (Some(response), _) => SdkError::response_error(error, response), - _ => unreachable!("phase has a response"), - } - } - - /// Represents the phase of an operation prior to serialization. - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct BeforeSerialization; - impl_phase!(BeforeSerialization, convert_construction_failure); - - #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct Serialization; - impl_phase!(Serialization, convert_construction_failure); - - /// Represents the phase of an operation prior to transmitting a request over the network. - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct BeforeTransmit; - impl_phase!(BeforeTransmit, convert_dispatch_error); - - #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct Transmit; - impl_phase!(Transmit, convert_dispatch_error); - - /// Represents the phase of an operation after receiving a response, but before parsing that response. - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct BeforeDeserialization; - impl_phase!(BeforeDeserialization, convert_response_handling_error); - - #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct Deserialization; - impl_phase!(Deserialization, convert_response_handling_error); - - /// Represents the phase of an operation after parsing a response. - #[derive(Default, Debug)] - #[non_exhaustive] - pub struct AfterDeserialization; - impl_phase!(AfterDeserialization, convert_response_handling_error); -} - /// A container for the data currently available to an interceptor. /// /// Different context is available based on which phase the operation is currently in. For example, -/// context in the [`phase::BeforeSerialization`] phase won't have a `request` yet since the input hasn't been -/// serialized at that point. But once it gets into the [`phase::BeforeTransmit`] phase, the `request` will be set. -pub struct InterceptorContext { - input: Option, - output_or_error: Option>, - request: Option, - response: Option, +/// context in the [`Phase::BeforeSerialization`] phase won't have a `request` yet since the input hasn't been +/// serialized at that point. But once it gets into the [`Phase::BeforeTransmit`] phase, the `request` will be set. +#[derive(Debug)] +pub struct InterceptorContext +where + E: Debug, +{ + pub(crate) input: Option, + pub(crate) output_or_error: Option>>, + pub(crate) request: Option, + pub(crate) response: Option, phase: Phase, + tainted: bool, + request_checkpoint: Option, } -// -// All phases -// -impl InterceptorContext<(), Input, Output, Error> { - /// Creates a new interceptor context in the [`phase::BeforeSerialization`] phase. - pub fn new( - input: Input, - ) -> InterceptorContext { +impl InterceptorContext { + /// Creates a new interceptor context in the [`Phase::BeforeSerialization`] phase. + pub fn new(input: Input) -> InterceptorContext { InterceptorContext { input: Some(input), output_or_error: None, request: None, response: None, - phase: Default::default(), + phase: Phase::BeforeSerialization, + tainted: false, + request_checkpoint: None, } } } -impl InterceptorContext { + +impl InterceptorContext +where + E: Debug, +{ /// Decomposes the context into its constituent parts. #[doc(hidden)] #[allow(clippy::type_complexity)] @@ -192,114 +93,60 @@ impl InterceptorContext { self, ) -> ( Option, - Option>, + Option>>, Option, Option, - Phase, ) { ( self.input, self.output_or_error, self.request, self.response, - self.phase, ) } -} -// -// BeforeSerialization phase methods -// -impl InterceptorContext { + pub fn finalize(self) -> Result> { + let Self { + output_or_error, + response, + phase, + .. + } = self; + output_or_error + .expect("output_or_error must always beset before finalize is called.") + .map_err(|error| OrchestratorError::into_sdk_error(error, &phase, response)) + } + /// Retrieve the input for the operation being invoked. pub fn input(&self) -> &I { self.input .as_ref() - .expect("input is present in phase::BeforeSerialization") + .expect("input is present in 'before serialization'") } /// Retrieve the input for the operation being invoked. pub fn input_mut(&mut self) -> &mut I { self.input .as_mut() - .expect("input is present in phase::BeforeSerialization") - } - - /// Advance to the next phase. - #[doc(hidden)] - pub fn into_serialization_phase(self) -> InterceptorContext { - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: phase::Serialization::default(), - } + .expect("input is present in 'before serialization'") } -} -// -// Serialization phase methods -// -impl InterceptorContext { /// Takes ownership of the input. pub fn take_input(&mut self) -> Option { self.input.take() } + /// Set the request for the operation being invoked. pub fn set_request(&mut self, request: Request) { - debug_assert!( - self.request.is_none(), - "called set_request but a request was already set" - ); self.request = Some(request); } - /// Advance to the next phase. - #[doc(hidden)] - pub fn into_before_transmit_phase(self) -> InterceptorContext { - debug_assert!( - self.input.is_none(), - "input must be taken before going into phase::BeforeTransmit" - ); - debug_assert!( - self.request.is_some(), - "request must be set before going into phase::BeforeTransmit" - ); - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: Default::default(), - } - } -} - -// -// BeforeTransmit phase methods -// -impl InterceptorContext { - /// Creates a new interceptor context in the [`phase::BeforeTransmit`] phase. - pub fn new( - input: Option, - request: HttpRequest, - ) -> InterceptorContext { - InterceptorContext { - input, - output_or_error: None, - request: Some(request), - response: None, - phase: Default::default(), - } - } - /// Retrieve the transmittable request for the operation being invoked. /// This will only be available once request marshalling has completed. pub fn request(&self) -> &Request { self.request .as_ref() - .expect("request populated in phase::BeforeTransmit") + .expect("request populated in 'before transmit'") } /// Retrieve the transmittable request for the operation being invoked. @@ -307,266 +154,135 @@ impl InterceptorContext { pub fn request_mut(&mut self) -> &mut Request { self.request .as_mut() - .expect("request populated in phase::BeforeTransmit") + .expect("request populated in 'before transmit'") } - #[doc(hidden)] - pub fn into_transmit_phase(self) -> InterceptorContext { - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: Default::default(), - } - } -} - -// -// Transmit phase methods -// -impl InterceptorContext { /// Takes ownership of the request. - #[doc(hidden)] pub fn take_request(&mut self) -> Request { - debug_assert!(self.request.is_some()); self.request .take() - .expect("take request once during transmit") + .expect("take request once during 'transmit'") } - #[doc(hidden)] + /// Set the response for the operation being invoked. pub fn set_response(&mut self, response: Response) { - debug_assert!( - self.response.is_none(), - "called set_response but a response was already set" - ); self.response = Some(response); } - #[doc(hidden)] - pub fn into_before_deserialization_phase( - self, - ) -> InterceptorContext { - debug_assert!( - self.request.is_none(), - "request must be taken before going into phase::BeforeDeserialization" - ); - debug_assert!( - self.response.is_some(), - "response must be set to before going into phase::BeforeDeserialization" - ); - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: Default::default(), - } - } -} - -impl InterceptorContext { - /// Returns the response. - pub fn response(&self) -> &Response { - self.response - .as_ref() - .expect("response set in phase::BeforeDeserialization") - } - - /// Returns a mutable reference to the response. - pub fn response_mut(&mut self) -> &mut Response { - self.response - .as_mut() - .expect("response set in phase::BeforeDeserialization") - } - - #[doc(hidden)] - pub fn into_deserialization_phase(self) -> InterceptorContext { - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: Default::default(), - } - } -} - -impl InterceptorContext { /// Returns the response. pub fn response(&self) -> &Response { - self.response - .as_ref() - .expect("response set in phase::Deserialization") + self.response.as_ref().expect( + "response set in 'before deserialization' and available in the phases following it", + ) } /// Returns a mutable reference to the response. pub fn response_mut(&mut self) -> &mut Response { - self.response - .as_mut() - .expect("response set in phase::Deserialization") + self.response.as_mut().expect( + "response is set in 'before deserialization' and available in the following phases", + ) } - #[doc(hidden)] - pub fn set_output_or_error(&mut self, output: Result) { - debug_assert!(self.output_or_error.is_none()); + /// Set the output or error for the operation being invoked. + pub fn set_output_or_error(&mut self, output: Result>) { self.output_or_error = Some(output); } - #[doc(hidden)] - pub fn into_after_deserialization_phase( - self, - ) -> InterceptorContext { - debug_assert!( - self.output_or_error.is_some(), - "output must be set to before going into phase::AfterDeserialization" - ); - InterceptorContext { - input: self.input, - output_or_error: self.output_or_error, - request: self.request, - response: self.response, - phase: Default::default(), - } - } -} - -impl InterceptorContext { - /// Returns the response. - pub fn response(&self) -> &Response { - self.response - .as_ref() - .expect("response set in phase::BeforeDeserialization") - } - - /// Returns a mutable reference to the response. - pub fn response_mut(&mut self) -> &mut Response { - self.response - .as_mut() - .expect("response set in phase::BeforeDeserialization") - } - /// Returns the deserialized output or error. - pub fn output_or_error(&self) -> Result<&O, &E> { + pub fn output_or_error(&self) -> Result<&O, &OrchestratorError> { self.output_or_error .as_ref() - .expect("output set in phase::AfterDeserialization") + .expect("output set in Phase::AfterDeserialization") .as_ref() } /// Returns the mutable reference to the deserialized output or error. - pub fn output_or_error_mut(&mut self) -> &mut Result { + pub fn output_or_error_mut(&mut self) -> &mut Result> { self.output_or_error .as_mut() - .expect("output set in phase::AfterDeserialization") + .expect("output set in 'after deserialization'") } + /// Advance to the Serialization phase. #[doc(hidden)] - pub fn finalize(self) -> Result> { - self.output_or_error - .expect("output already populated in the response handling phase") - .map_err(|error| { - SdkError::service_error( - error, - self.response - .expect("raw response already populated in the response handling phase"), - ) - }) - } -} - -// This isn't great since it relies on a lot of runtime checking, but the -// compiler doesn't exactly make it easy to handle phase changes in a `loop`. -#[doc(hidden)] -pub struct AttemptCheckpoint { - tainted: bool, - checkpointed_request: Option, - before_transmit: Option>, - transmit: Option>, - before_deserialization: Option>, - deserialization: Option>, - after_deserialization: Option>, -} - -impl AttemptCheckpoint { - pub fn new(before_transmit: InterceptorContext) -> Self { - Self { - tainted: false, - checkpointed_request: Self::try_clone(before_transmit.request()), - before_transmit: Some(before_transmit), - transmit: None, - before_deserialization: None, - deserialization: None, - after_deserialization: None, - } + pub fn enter_serialization_phase(&mut self) { + debug_assert!( + self.phase.is_before_serialization(), + "called enter_serialization_phase but phase is not before 'serialization'" + ); + self.phase = Phase::Serialization; } - pub fn before_transmit(&mut self) -> &mut InterceptorContext { + /// Advance to the BeforeTransmit phase. + #[doc(hidden)] + pub fn enter_before_transmit_phase(&mut self) { + debug_assert!( + self.phase.is_serialization(), + "called enter_before_transmit_phase but phase is not 'serialization'" + ); + debug_assert!( + self.input.is_none(), + "input must be taken before calling enter_before_transmit_phase" + ); + debug_assert!( + self.request.is_some(), + "request must be set before calling enter_before_transmit_phase" + ); + self.request_checkpoint = try_clone(self.request()); self.tainted = true; - self.before_transmit - .as_mut() - .expect("must be in the before transmit phase") + self.phase = Phase::BeforeTransmit; } - pub fn transmit(&mut self) -> &mut InterceptorContext { - self.transmit - .as_mut() - .expect("must be in the transmit phase") - } - - pub fn before_deser(&mut self) -> &mut InterceptorContext { - self.before_deserialization - .as_mut() - .expect("must be in the before deserialization phase") - } - - pub fn deser(&mut self) -> &mut InterceptorContext { - self.deserialization - .as_mut() - .expect("must be in the deserialization phase") - } - - pub fn after_deser(&mut self) -> &mut InterceptorContext { - self.after_deserialization - .as_mut() - .expect("must be in the after deserialization phase") - } - - pub fn transition_to_transmit(&mut self) { - self.transmit = Some( - self.before_transmit - .take() - .expect("must be in the before transmit phase") - .into_transmit_phase(), + /// Advance to the Transmit phase. + #[doc(hidden)] + pub fn enter_transmit_phase(&mut self) { + debug_assert!( + self.phase.is_before_transmit(), + "called enter_transmit_phase but phase is not before transmit" ); + self.phase = Phase::Transmit; } - pub fn transition_to_deserialization(&mut self) { - self.deserialization = Some( - self.before_deserialization - .take() - .expect("must be in the before deserialization phase") - .into_deserialization_phase(), - ) + /// Advance to the BeforeDeserialization phase. + #[doc(hidden)] + pub fn enter_before_deserialization_phase(&mut self) { + debug_assert!( + self.phase.is_transmit(), + "called enter_before_deserialization_phase but phase is not 'transmit'" + ); + debug_assert!( + self.request.is_none(), + "request must be taken before entering the 'before deserialization' phase" + ); + debug_assert!( + self.response.is_some(), + "response must be set to before entering the 'before deserialization' phase" + ); + self.phase = Phase::BeforeDeserialization; } - pub fn transition_to_before_deserialization(&mut self) { - self.before_deserialization = Some( - self.transmit - .take() - .expect("must be in the transmit phase") - .into_before_deserialization_phase(), - ) + /// Advance to the Deserialization phase. + #[doc(hidden)] + pub fn enter_deserialization_phase(&mut self) { + debug_assert!( + self.phase.is_before_deserialization(), + "called enter_deserialization_phase but phase is not 'before deserialization'" + ); + self.phase = Phase::Deserialization; } - pub fn transition_to_after_deserialization(&mut self) { - self.after_deserialization = Some( - self.deserialization - .take() - .expect("must be in the deserialization phase") - .into_after_deserialization_phase(), - ) + /// Advance to the AfterDeserialization phase. + #[doc(hidden)] + pub fn enter_after_deserialization_phase(&mut self) { + debug_assert!( + self.phase.is_deserialization(), + "called enter_after_deserialization_phase but phase is not 'deserialization'" + ); + debug_assert!( + self.output_or_error.is_some(), + "output must be set to before entering the 'after deserialization' phase" + ); + self.phase = Phase::AfterDeserialization; } // Returns false if rewinding isn't possible @@ -575,73 +291,55 @@ impl AttemptCheckpoint { if !self.tainted { return true; } - // If checkpointed_request was never set, then this is not a retryable request - if self.checkpointed_request.is_none() { + // If request_checkpoint was never set, then this is not a retryable request + if self.request_checkpoint.is_none() { return false; } // Otherwise, rewind back to the beginning of BeforeTransmit // TODO(enableNewSmithyRuntime): Also rewind the ConfigBag - fn into_input(context: InterceptorContext

) -> Option { - context.into_parts().0 - } - // Take the input from the current phase - let input = None - .or(self.before_transmit.take().map(into_input)) - .or(self.transmit.take().map(into_input)) - .or(self.before_deserialization.take().map(into_input)) - .or(self.deserialization.take().map(into_input)) - .or(self.after_deserialization.take().map(into_input)) - .expect("at least one phase must be in progress"); - let fresh_request = - Self::try_clone(self.checkpointed_request.as_ref().expect("checked above")) - .expect("cloneable request"); - self.before_transmit = Some(InterceptorContext::::new( - input, - fresh_request, - )); + self.phase = Phase::BeforeTransmit; + self.request = try_clone(self.request_checkpoint.as_ref().expect("checked above")); + self.response = None; + self.output_or_error = None; true } - pub fn into_error(self, reason: BoxError) -> SdkError { - fn err( - context: InterceptorContext

, - ) -> Box SdkError> { - Box::new(move |reason| { - let (_input, output_or_error, _request, response, phase) = context.into_parts(); - phase.convert_error(reason, output_or_error, response) - }) + /// Mark this context as failed due to errors during the operation. Any errors already contained + /// by the context will be replaced by the given error. + pub fn fail(&mut self, error: OrchestratorError) { + if !self.is_failed() { + trace!( + "orchestrator is transitioning to the 'failure' phase from the '{:?}' phase", + self.phase + ); + } + if let Some(Err(existing_err)) = mem::replace(&mut self.output_or_error, Some(Err(error))) { + error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err); } - // Convert the current phase into an error - (None - .or(self.before_transmit.map(err)) - .or(self.transmit.map(err)) - .or(self.before_deserialization.map(err)) - .or(self.deserialization.map(err)) - .or(self.after_deserialization.map(err)) - .expect("at least one phase must be in progress"))(reason) } - pub fn finalize(self) -> Result> { - self.after_deserialization - .expect("must be in the after deserialization phase") - .finalize() + /// Return `true` if this context's `output_or_error` is an error. Otherwise, return `false`. + pub fn is_failed(&self) -> bool { + self.output_or_error + .as_ref() + .map(Result::is_err) + .unwrap_or_default() } +} - pub fn try_clone(request: &HttpRequest) -> Option { - let cloned_body = request.body().try_clone()?; - let mut cloned_request = ::http::Request::builder() - .uri(request.uri().clone()) - .method(request.method()); - *cloned_request - .headers_mut() - .expect("builder has not been modified, headers must be valid") = - request.headers().clone(); - Some( - cloned_request - .body(cloned_body) - .expect("a clone of a valid request should be a valid request"), - ) - } +fn try_clone(request: &HttpRequest) -> Option { + let cloned_body = request.body().try_clone()?; + let mut cloned_request = ::http::Request::builder() + .uri(request.uri().clone()) + .method(request.method()); + *cloned_request + .headers_mut() + .expect("builder has not been modified, headers must be valid") = request.headers().clone(); + Some( + cloned_request + .body(cloned_body) + .expect("a clone of a valid request should be a valid request"), + ) } #[cfg(test)] @@ -657,40 +355,38 @@ mod tests { let input = TypedBox::new("input".to_string()).erase(); let output = TypedBox::new("output".to_string()).erase(); - let mut context = InterceptorContext::<()>::new(input); + let mut context = InterceptorContext::new(input); assert_eq!("input", context.input().downcast_ref::().unwrap()); context.input_mut(); - let mut context = context.into_serialization_phase(); + context.enter_serialization_phase(); let _ = context.take_input(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); - let mut checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); - checkpoint.before_transmit().request(); - checkpoint.before_transmit().request_mut(); + context.enter_before_transmit_phase(); + context.request(); + context.request_mut(); - checkpoint.transition_to_transmit(); - let _ = checkpoint.transmit().take_request(); - checkpoint - .transmit() - .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + context.enter_transmit_phase(); + let _ = context.take_request(); + context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); - checkpoint.transition_to_before_deserialization(); - checkpoint.before_deser().response(); - checkpoint.before_deser().response_mut(); + context.enter_before_deserialization_phase(); + context.response(); + context.response_mut(); - checkpoint.transition_to_deserialization(); - checkpoint.deser().response(); - checkpoint.deser().response_mut(); - checkpoint.deser().set_output_or_error(Ok(output)); + context.enter_deserialization_phase(); + context.response(); + context.response_mut(); + context.set_output_or_error(Ok(output)); - checkpoint.transition_to_after_deserialization(); - checkpoint.after_deser().response(); - checkpoint.after_deser().response_mut(); - let _ = checkpoint.after_deser().output_or_error(); - let _ = checkpoint.after_deser().output_or_error_mut(); + context.enter_after_deserialization_phase(); + context.response(); + context.response_mut(); + let _ = context.output_or_error(); + let _ = context.output_or_error_mut(); - let output = checkpoint.finalize().expect("success"); + let output = context.output_or_error.unwrap().expect("success"); assert_eq!("output", output.downcast_ref::().unwrap()); } @@ -701,7 +397,7 @@ mod tests { struct Error; impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("dontcare") + f.write_str("don't care") } } impl std::error::Error for Error {} @@ -711,90 +407,72 @@ mod tests { let output = TypedBox::new("output".to_string()).erase(); let error = TypedBox::new(Error).erase_error(); - let context = InterceptorContext::<()>::new(input); + let mut context = InterceptorContext::new(input); assert_eq!("input", context.input().downcast_ref::().unwrap()); - let mut context = context.into_serialization_phase(); + context.enter_serialization_phase(); let _ = context.take_input(); context.set_request( http::Request::builder() - .header("test", "the-original-unmutated-request") + .header("test", "the-original-un-mutated-request") .body(SdkBody::empty()) .unwrap(), ); - - let mut checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); + context.enter_before_transmit_phase(); // Modify the test header post-checkpoint to simulate modifying the request for signing or a mutating interceptor - checkpoint - .before_transmit() - .request_mut() - .headers_mut() - .remove("test"); - checkpoint - .before_transmit() - .request_mut() - .headers_mut() - .insert( - "test", - HeaderValue::from_static("request-modified-after-signing"), - ); + context.request_mut().headers_mut().remove("test"); + context.request_mut().headers_mut().insert( + "test", + HeaderValue::from_static("request-modified-after-signing"), + ); - checkpoint.transition_to_transmit(); - let request = checkpoint.transmit().take_request(); + context.enter_transmit_phase(); + let request = context.take_request(); assert_eq!( "request-modified-after-signing", request.headers().get("test").unwrap() ); - checkpoint - .transmit() - .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); - checkpoint.transition_to_before_deserialization(); - checkpoint.transition_to_deserialization(); - checkpoint.deser().set_output_or_error(Err(error)); + context.enter_before_deserialization_phase(); + context.enter_deserialization_phase(); + context.set_output_or_error(Err(OrchestratorError::operation(error))); - assert!(checkpoint.rewind(&mut cfg)); + assert!(context.rewind(&mut cfg)); // Now after rewinding, the test header should be its original value assert_eq!( - "the-original-unmutated-request", - checkpoint - .before_transmit() - .request() - .headers() - .get("test") - .unwrap() + "the-original-un-mutated-request", + context.request().headers().get("test").unwrap() ); - checkpoint.transition_to_transmit(); - let _ = checkpoint.transmit().take_request(); - checkpoint - .transmit() - .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + context.enter_transmit_phase(); + let _ = context.take_request(); + context.set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); - checkpoint.transition_to_before_deserialization(); - checkpoint.transition_to_deserialization(); - checkpoint.deser().set_output_or_error(Ok(output)); + context.enter_before_deserialization_phase(); + context.enter_deserialization_phase(); + context.set_output_or_error(Ok(output)); - checkpoint.transition_to_after_deserialization(); + context.enter_after_deserialization_phase(); - let output = checkpoint.finalize().expect("success"); + let output = context.output_or_error.unwrap().expect("success"); assert_eq!("output", output.downcast_ref::().unwrap()); } #[test] fn try_clone_clones_all_data() { let request = ::http::Request::builder() - .uri(Uri::from_static("http://www.amazon.com")) + .uri(Uri::from_static("https://www.amazon.com")) .method("POST") .header(CONTENT_LENGTH, 456) .header(AUTHORIZATION, "Token: hello") .body(SdkBody::from("hello world!")) .expect("valid request"); - let cloned = AttemptCheckpoint::try_clone(&request).expect("request is cloneable"); + let cloned = try_clone(&request).expect("request is cloneable"); - assert_eq!(&Uri::from_static("http://www.amazon.com"), cloned.uri()); + assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri()); assert_eq!("POST", cloned.method()); assert_eq!(2, cloned.headers().len()); assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap(),); diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/phase.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/phase.rs new file mode 100644 index 000000000..01c7f2f12 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/phase.rs @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[derive(Debug)] +#[non_exhaustive] +pub enum Phase { + /// Represents the phase of an operation prior to serialization. + BeforeSerialization, + /// Represents the phase of an operation where the request is serialized. + Serialization, + /// Represents the phase of an operation prior to transmitting a request over the network. + BeforeTransmit, + /// Represents the phase of an operation where the request is transmitted over the network. + Transmit, + /// Represents the phase of an operation prior to parsing a response. + BeforeDeserialization, + /// Represents the phase of an operation where the response is parsed. + Deserialization, + /// Represents the phase of an operation after parsing a response. + AfterDeserialization, +} + +impl Phase { + pub fn is_before_serialization(&self) -> bool { + matches!(self, Self::BeforeSerialization) + } + + pub fn is_serialization(&self) -> bool { + matches!(self, Self::Serialization) + } + + pub fn is_before_transmit(&self) -> bool { + matches!(self, Self::BeforeTransmit) + } + + pub fn is_transmit(&self) -> bool { + matches!(self, Self::Transmit) + } + + pub fn is_before_deserialization(&self) -> bool { + matches!(self, Self::BeforeDeserialization) + } + + pub fn is_deserialization(&self) -> bool { + matches!(self, Self::Deserialization) + } + + pub fn is_after_deserialization(&self) -> bool { + matches!(self, Self::AfterDeserialization) + } +} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/wrappers.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/wrappers.rs new file mode 100644 index 000000000..b9db1e404 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context/wrappers.rs @@ -0,0 +1,253 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use super::{Error, Input, InterceptorContext, Output}; +use crate::client::interceptors::context::{Request, Response}; +use crate::client::orchestrator::OrchestratorError; +use std::fmt::Debug; + +macro_rules! output { + (&Option>) => { + Option> + }; + (&Option<$ty:ty>) => { + Option<&$ty> + }; + (&mut Option<$ty:ty>) => { + Option<&mut $ty> + }; + (&Result<$o_ty:ty, $e_ty:ty>) => { + Result<&$o_ty, &$e_ty> + }; + (&$($tt:tt)+) => { + &$($tt)+ + }; + (&mut $($tt:tt)+) => { + &mut $($tt)+ + }; +} + +macro_rules! declare_method { + (&mut $name:ident, $inner_name:ident, $doc:literal, Option<$ty:ty>) => { + #[doc=$doc] + pub fn $name(&mut self) -> Option<&mut $ty> { + self.inner.$inner_name.as_ref() + } + }; + (&$name:ident, $inner_name:ident, $doc:literal, Option<$ty:ty>) => { + #[doc=$doc] + pub fn $name(&self) -> Option<$ty> { + self.inner.$inner_name.as_mut() + } + }; + (&mut $name:ident, $doc:literal, $($tt:tt)+) => { + #[doc=$doc] + pub fn $name(&mut self) -> output!(&mut $($tt)+) { + self.inner.$name() + } + }; + (&$name:ident, $doc:literal, $($tt:tt)+) => { + #[doc=$doc] + pub fn $name(&self) -> output!(&$($tt)+) { + self.inner.$name() + } + }; +} + +macro_rules! declare_known_method { + (output_or_error: &mut $($tt:tt)+) => { + declare_method!(&mut output_or_error_mut, "Returns a mutable reference to the deserialized output or error.", $($tt)+); + }; + (output_or_error: &$($tt:tt)+) => { + declare_method!(&output_or_error, "Returns a reference to the deserialized output or error.", $($tt)+); + }; + (input: &mut $($tt:tt)+) => { + declare_method!(&mut input_mut, "Returns a mutable reference to the input.", $($tt)+); + }; + (input: &$($tt:tt)+) => { + declare_method!(&input, "Returns a reference to the input.", $($tt)+); + }; + (request: &mut $($tt:tt)+) => { + declare_method!(&mut request_mut, "Returns a mutable reference to the transmittable request for the operation being invoked.", $($tt)+); + }; + (request: &$($tt:tt)+) => { + declare_method!(&request, "Returns a reference to the transmittable request for the operation being invoked.", $($tt)+); + }; + (response: &mut $($tt:tt)+) => { + declare_method!(&mut response_mut, "Returns a mutable reference to the response.", $($tt)+); + }; + (response: &$($tt:tt)+) => { + declare_method!(&response, "Returns a reference to the response.", $($tt)+); + }; +} + +macro_rules! declare_wrapper { + (($ref_struct_name:ident $mut_struct_name:ident)$($tt:tt)+) => { + pub struct $ref_struct_name<'a, I = Input, O = Output, E = Error> + where E: Debug { + inner: &'a InterceptorContext, + } + + impl<'a, I, O, E: Debug> From<&'a InterceptorContext> for $ref_struct_name<'a, I, O, E> + { + fn from(inner: &'a InterceptorContext) -> Self { + Self { inner } + } + } + + impl<'a, I, O, E: Debug> $ref_struct_name<'a, I, O, E> { + declare_ref_wrapper_methods!($($tt)+); + } + + pub struct $mut_struct_name<'a, I = Input, O = Output, E = Error> + where E: Debug { + inner: &'a mut InterceptorContext, + } + + impl<'a, I, O, E: Debug> From<&'a mut InterceptorContext> for $mut_struct_name<'a, I, O, E> + { + fn from(inner: &'a mut InterceptorContext) -> Self { + Self { inner } + } + } + + impl<'a, I, O, E: Debug> $mut_struct_name<'a, I, O, E> { + declare_ref_wrapper_methods!($($tt)+); + declare_mut_wrapper_methods!($($tt)+); + } + }; +} + +macro_rules! declare_ref_wrapper_methods { + (($field:ident: $($head:tt)+)$($tail:tt)+) => { + declare_known_method!($field: &$($head)+); + declare_ref_wrapper_methods!($($tail)+); + }; + (($field:ident: $($tt:tt)+)) => { + declare_known_method!($field: &$($tt)+); + }; +} + +macro_rules! declare_mut_wrapper_methods { + (($field:ident: $($head:tt)+)$($tail:tt)+) => { + declare_known_method!($field: &mut $($head)+); + declare_mut_wrapper_methods!($($tail)+); + }; + (($field:ident: $($tt:tt)+)) => { + declare_known_method!($field: &mut $($tt)+); + }; +} + +declare_wrapper!( + (BeforeSerializationInterceptorContextRef BeforeSerializationInterceptorContextMut) + (input: I) +); + +declare_wrapper!( + (BeforeTransmitInterceptorContextRef BeforeTransmitInterceptorContextMut) + (input: I) + (request: Request) +); + +declare_wrapper!( + (BeforeDeserializationInterceptorContextRef BeforeDeserializationInterceptorContextMut) + (input: I) + (request: Request) + (response: Response) +); + +declare_wrapper!( + (AfterDeserializationInterceptorContextRef AfterDeserializationInterceptorContextMut) + (input: I) + (request: Request) + (response: Response) + (output_or_error: Result>) +); + +// Why are all the rest of these defined with a macro but these last two aren't? I simply ran out of +// time. Consider updating the macros to support these last two if you're looking for a challenge. +// - Zelda + +pub struct FinalizerInterceptorContextRef<'a, I = Input, O = Output, E = Error> +where + E: Debug, +{ + inner: &'a InterceptorContext, +} + +impl<'a, I, O, E: Debug> From<&'a InterceptorContext> + for FinalizerInterceptorContextRef<'a, I, O, E> +{ + fn from(inner: &'a InterceptorContext) -> Self { + Self { inner } + } +} + +impl<'a, I, O, E: Debug> FinalizerInterceptorContextRef<'a, I, O, E> { + pub fn input(&self) -> Option<&I> { + self.inner.input.as_ref() + } + + pub fn request(&self) -> Option<&Request> { + self.inner.request.as_ref() + } + + pub fn response(&self) -> Option<&Response> { + self.inner.response.as_ref() + } + + pub fn output_or_error(&self) -> Option>> { + self.inner.output_or_error.as_ref().map(|o| o.as_ref()) + } +} + +pub struct FinalizerInterceptorContextMut<'a, I = Input, O = Output, E = Error> +where + E: Debug, +{ + inner: &'a mut InterceptorContext, +} + +impl<'a, I, O, E: Debug> From<&'a mut InterceptorContext> + for FinalizerInterceptorContextMut<'a, I, O, E> +{ + fn from(inner: &'a mut InterceptorContext) -> Self { + Self { inner } + } +} + +impl<'a, I, O, E: Debug> FinalizerInterceptorContextMut<'a, I, O, E> { + pub fn input(&self) -> Option<&I> { + self.inner.input.as_ref() + } + + pub fn request(&self) -> Option<&Request> { + self.inner.request.as_ref() + } + + pub fn response(&self) -> Option<&Response> { + self.inner.response.as_ref() + } + + pub fn output_or_error(&self) -> Option>> { + self.inner.output_or_error.as_ref().map(|o| o.as_ref()) + } + + pub fn input_mut(&mut self) -> Option<&mut I> { + self.inner.input.as_mut() + } + + pub fn request_mut(&mut self) -> Option<&mut Request> { + self.inner.request.as_mut() + } + + pub fn response_mut(&mut self) -> Option<&mut Response> { + self.inner.response.as_mut() + } + + pub fn output_or_error_mut(&mut self) -> Option<&mut Result>> { + self.inner.output_or_error.as_mut() + } +} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs index 8d8b82897..402626977 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs @@ -61,6 +61,11 @@ impl InterceptorError { interceptor_error_fn!(modify_before_completion => ModifyBeforeCompletion (with source)); interceptor_error_fn!(read_after_execution => ReadAfterExecution (with source)); + interceptor_error_fn!(modify_before_attempt_completion_failed => ModifyBeforeAttemptCompletion (with source)); + interceptor_error_fn!(read_after_attempt_failed => ReadAfterAttempt (with source)); + interceptor_error_fn!(modify_before_completion_failed => ModifyBeforeCompletion (with source)); + interceptor_error_fn!(read_after_execution_failed => ReadAfterExecution (with source)); + interceptor_error_fn!(invalid_request_access => InvalidRequestAccess (invalid request access)); interceptor_error_fn!(invalid_response_access => InvalidResponseAccess (invalid response access)); interceptor_error_fn!(invalid_input_access => InvalidInputAccess (invalid input access)); diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index 37c379a10..c2c9e9cd5 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -3,6 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ +/// Errors that can occur while running the orchestrator. +mod error; + use crate::client::auth::{AuthOptionResolver, AuthOptionResolverParams, HttpAuthSchemes}; use crate::client::identity::IdentityResolvers; use crate::client::interceptors::context::{Error, Input, Output}; @@ -20,6 +23,8 @@ use std::pin::Pin; use std::sync::Arc; use std::time::SystemTime; +pub use error::OrchestratorError; + pub type HttpRequest = http::Request; pub type HttpResponse = http::Response; pub type BoxError = Box; @@ -31,12 +36,18 @@ pub trait RequestSerializer: Send + Sync + fmt::Debug { } pub trait ResponseDeserializer: Send + Sync + fmt::Debug { - fn deserialize_streaming(&self, response: &mut HttpResponse) -> Option> { + fn deserialize_streaming( + &self, + response: &mut HttpResponse, + ) -> Option>> { let _ = response; None } - fn deserialize_nonstreaming(&self, response: &HttpResponse) -> Result; + fn deserialize_nonstreaming( + &self, + response: &HttpResponse, + ) -> Result>; } pub trait Connection: Send + Sync + fmt::Debug { diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs new file mode 100644 index 000000000..36b12d9d2 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator/error.rs @@ -0,0 +1,117 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use super::BoxError; +use crate::client::interceptors::context::phase::Phase; +use crate::client::interceptors::InterceptorError; +use crate::client::orchestrator::HttpResponse; +use crate::type_erasure::TypeErasedError; +use aws_smithy_http::result::{ConnectorError, SdkError}; +use std::fmt::Debug; + +#[derive(Debug)] +#[non_exhaustive] +pub enum OrchestratorError { + /// An error occurred within an interceptor. + Interceptor { err: InterceptorError }, + /// An error returned by a service. + Operation { err: E }, + /// A general orchestrator error. + Other { err: BoxError }, +} + +impl OrchestratorError { + /// Create a new `OrchestratorError` from a [`BoxError`]. + pub fn other(err: BoxError) -> Self { + Self::Other { err } + } + + /// Create a new `OrchestratorError` from an error received from a service. + pub fn operation(err: E) -> Self { + Self::Operation { err } + } + + /// Create a new `OrchestratorError` from an [`InterceptorError`]. + pub fn interceptor(err: InterceptorError) -> Self { + Self::Interceptor { err } + } + + /// Convert the `OrchestratorError` into an [`SdkError`]. + pub fn into_sdk_error( + self, + phase: &Phase, + response: Option, + ) -> SdkError { + match self { + Self::Interceptor { err } => { + use Phase::*; + match phase { + BeforeSerialization | Serialization => SdkError::construction_failure(err), + BeforeTransmit | Transmit => match response { + Some(response) => SdkError::response_error(err, response), + None => SdkError::dispatch_failure(ConnectorError::other(err.into(), None)), + }, + BeforeDeserialization | Deserialization | AfterDeserialization => { + SdkError::response_error(err, response.expect("phase has a response")) + } + } + } + Self::Operation { err } => { + debug_assert!(phase.is_after_deserialization(), "operation errors are a result of successfully receiving and parsing a response from the server. Therefore, we must be in the 'After Deserialization' phase."); + SdkError::service_error(err, response.expect("phase has a response")) + } + Self::Other { err } => { + use Phase::*; + match phase { + BeforeSerialization | Serialization => SdkError::construction_failure(err), + BeforeTransmit | Transmit => convert_dispatch_error(err, response), + BeforeDeserialization | Deserialization | AfterDeserialization => { + SdkError::response_error(err, response.expect("phase has a response")) + } + } + } + } + } +} + +fn convert_dispatch_error( + err: BoxError, + response: Option, +) -> SdkError { + let err = match err.downcast::() { + Ok(connector_error) => { + return SdkError::dispatch_failure(*connector_error); + } + Err(e) => e, + }; + match response { + Some(response) => SdkError::response_error(err, response), + None => SdkError::dispatch_failure(ConnectorError::other(err, None)), + } +} + +impl From for OrchestratorError +where + E: Debug + std::error::Error + 'static, +{ + fn from(err: InterceptorError) -> Self { + Self::interceptor(err) + } +} + +impl From for OrchestratorError +where + E: Debug + std::error::Error + 'static, +{ + fn from(err: BoxError) -> Self { + Self::other(err) + } +} + +impl From for OrchestratorError { + fn from(err: TypeErasedError) -> Self { + Self::operation(err) + } +} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs index 0d6678d35..fa42811ee 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs @@ -3,10 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::client::interceptors::context::phase::AfterDeserialization; use crate::client::interceptors::context::Error; use crate::client::interceptors::InterceptorContext; -use crate::client::orchestrator::BoxError; +use crate::client::orchestrator::{BoxError, OrchestratorError}; use crate::config_bag::ConfigBag; use aws_smithy_types::retry::ErrorKind; use std::fmt::Debug; @@ -24,7 +23,7 @@ pub trait RetryStrategy: Send + Sync + Debug { fn should_attempt_retry( &self, - context: &InterceptorContext, + context: &InterceptorContext, cfg: &ConfigBag, ) -> Result; } @@ -40,7 +39,7 @@ pub enum RetryReason { pub trait ClassifyRetry: Send + Sync + Debug { /// Run this classifier against an error to determine if it should be retried. Returns /// `Some(RetryKind)` if the error should be retried; Otherwise returns `None`. - fn classify_retry(&self, error: &Error) -> Option; + fn classify_retry(&self, error: &OrchestratorError) -> Option; } #[derive(Debug)] @@ -68,7 +67,7 @@ impl RetryClassifiers { } impl ClassifyRetry for RetryClassifiers { - fn classify_retry(&self, error: &Error) -> Option { + fn classify_retry(&self, error: &OrchestratorError) -> Option { // return the first non-None result self.inner.iter().find_map(|cr| cr.classify_retry(error)) } diff --git a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs index f2269c1c4..e8043fc00 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs @@ -226,7 +226,8 @@ impl TypeErasedError { #[cfg(test)] mod tests { - use super::*; + use super::{TypeErasedError, TypedBox}; + use std::fmt; #[derive(Debug)] struct Foo(&'static str); @@ -234,7 +235,7 @@ mod tests { struct Bar(isize); #[test] - fn test() { + fn test_typed_boxes() { let foo = TypedBox::new(Foo("1")); let bar = TypedBox::new(Bar(2)); @@ -268,4 +269,33 @@ mod tests { let foo = *foo_erased.downcast::().expect("it's a Foo"); assert_eq!("4", foo.0); } + + #[derive(Debug, Clone, PartialEq, Eq)] + struct TestErr { + inner: &'static str, + } + + impl TestErr { + fn new(inner: &'static str) -> Self { + Self { inner } + } + } + + impl fmt::Display for TestErr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Error: {}", self.inner) + } + } + + impl std::error::Error for TestErr {} + + #[test] + fn test_typed_erased_errors_can_be_downcast() { + let test_err = TestErr::new("something failed!"); + let type_erased_test_err = TypeErasedError::new(test_err.clone()); + let actual = type_erased_test_err + .downcast::() + .expect("type erased error can be downcast into original type"); + assert_eq!(test_err, *actual); + } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/interceptor.rs b/rust-runtime/aws-smithy-runtime/src/client/interceptor.rs index 4e5de36b6..39ea3b35b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/interceptor.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/interceptor.rs @@ -4,8 +4,9 @@ */ use aws_smithy_http::body::SdkBody; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use std::fmt; use std::marker::PhantomData; @@ -37,7 +38,7 @@ where { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let request = context.request_mut(); @@ -69,7 +70,7 @@ where { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut BeforeTransmitInterceptorContextMut<'_>, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let request = context.request_mut(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index e3be9839c..72cd3cbef 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -8,10 +8,7 @@ use crate::client::orchestrator::endpoints::orchestrate_endpoint; use crate::client::orchestrator::http::read_body; use crate::client::timeout::{MaybeTimeout, ProvideMaybeTimeoutConfig, TimeoutKind}; use aws_smithy_http::result::SdkError; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeSerialization; -use aws_smithy_runtime_api::client::interceptors::context::{ - AttemptCheckpoint, Error, Input, Output, -}; +use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output}; use aws_smithy_runtime_api::client::interceptors::{InterceptorContext, Interceptors}; use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpResponse}; use aws_smithy_runtime_api::client::retries::ShouldAttempt; @@ -25,41 +22,24 @@ pub mod endpoints; mod http; pub mod interceptors; -#[doc(hidden)] -#[macro_export] -macro_rules! handle_err { - ([$checkpoint:expr] => $expr:expr) => { +macro_rules! halt_on_err { + ([$ctx:ident] => $expr:expr) => { match $expr { Ok(ok) => ok, Err(err) => { - return Err($checkpoint.into_error(err.into())); - } - } - }; - ($ctx:expr => $expr:expr) => { - match $expr { - Ok(ok) => ok, - Err(err) => { - use aws_smithy_runtime_api::client::interceptors::context::phase::Phase; - let (_input, output_or_error, _request, response, phase) = $ctx.into_parts(); - return Err(phase.convert_error(err.into(), output_or_error, response)); + $ctx.fail(err); + return; } } }; } -#[doc(hidden)] -#[macro_export] -macro_rules! bail { - ([$checkpoint:expr], $reason:expr) => {{ - return Err($checkpoint.into_error($reason.into())); - }}; - ($ctx:expr, $reason:expr) => {{ - use aws_smithy_runtime_api::client::interceptors::context::phase::Phase; - let reason: BoxError = $reason.into(); - let (_input, output_or_error, _request, response, phase) = $ctx.into_parts(); - return Err(phase.convert_error(reason, output_or_error, response)); - }}; +macro_rules! continue_on_err { + ([$ctx:ident] => $expr:expr) => { + if let Err(err) = $expr { + $ctx.fail(err); + } + }; } #[tracing::instrument(skip_all)] @@ -71,134 +51,139 @@ pub async fn invoke( let cfg = &mut cfg; let mut interceptors = Interceptors::new(); - let context = InterceptorContext::<()>::new(input); - - // Client configuration - handle_err!(context => runtime_plugins.apply_client_configuration(cfg, interceptors.client_interceptors_mut())); - handle_err!(context => interceptors.client_read_before_execution(&context, cfg)); - // Operation configuration - handle_err!(context => runtime_plugins.apply_operation_configuration(cfg, interceptors.operation_interceptors_mut())); - handle_err!(context => interceptors.operation_read_before_execution(&context, cfg)); + let mut ctx = InterceptorContext::new(input); + if let Err(err) = apply_configuration(&mut ctx, cfg, &mut interceptors, runtime_plugins) { + return Err(SdkError::construction_failure(err)); + } let operation_timeout_config = cfg.maybe_timeout_config(TimeoutKind::Operation); - invoke_post_config(cfg, context, interceptors) - .maybe_timeout_with_config(operation_timeout_config) - .await + async { + // If running the pre-execution interceptors failed, then we skip running the op and run the + // final interceptors instead. + if !ctx.is_failed() { + try_op(&mut ctx, cfg, &interceptors).await; + } + finally_op(&mut ctx, cfg, &interceptors).await; + ctx.finalize() + } + .maybe_timeout_with_config(operation_timeout_config) + .await } -async fn invoke_post_config( +/// Apply configuration is responsible for apply runtime plugins to the config bag, as well as running +/// `read_before_execution` interceptors. If a failure occurs due to config construction, `invoke` +/// will raise it to the user. If an interceptor fails, then `invoke` +fn apply_configuration( + ctx: &mut InterceptorContext, cfg: &mut ConfigBag, - mut before_serialization: InterceptorContext, - interceptors: Interceptors, -) -> Result> { + interceptors: &mut Interceptors, + runtime_plugins: &RuntimePlugins, +) -> Result<(), BoxError> { + runtime_plugins.apply_client_configuration(cfg, interceptors.client_interceptors_mut())?; + continue_on_err!([ctx] =>interceptors.client_read_before_execution(ctx, cfg).map_err(Into::into)); + runtime_plugins + .apply_operation_configuration(cfg, interceptors.operation_interceptors_mut())?; + continue_on_err!([ctx] => interceptors.operation_read_before_execution(ctx, cfg).map_err(Into::into)); + + Ok(()) +} + +async fn try_op(ctx: &mut InterceptorContext, cfg: &mut ConfigBag, interceptors: &Interceptors) { // Before serialization - handle_err!(before_serialization => interceptors.read_before_serialization(&before_serialization, cfg)); - handle_err!(before_serialization => interceptors.modify_before_serialization(&mut before_serialization, cfg)); + halt_on_err!([ctx] => interceptors.read_before_serialization(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.modify_before_serialization(ctx, cfg).map_err(Into::into)); // Serialization - let mut serialization = before_serialization.into_serialization_phase(); + ctx.enter_serialization_phase(); { let request_serializer = cfg.request_serializer(); - let request = handle_err!(serialization => request_serializer - .serialize_input(serialization.take_input().expect("input set at this point"))); - serialization.set_request(request); + let input = ctx.take_input().expect("input set at this point"); + let request = + halt_on_err!([ctx] => request_serializer.serialize_input(input).map_err(Into::into)); + ctx.set_request(request); } // Before transmit - let mut before_transmit = serialization.into_before_transmit_phase(); - handle_err!(before_transmit => interceptors.read_after_serialization(&before_transmit, cfg)); - handle_err!(before_transmit => interceptors.modify_before_retry_loop(&mut before_transmit, cfg)); - - { - let retry_strategy = cfg.retry_strategy(); - match retry_strategy.should_attempt_initial_request(cfg) { - // Yes, let's make a request - Ok(ShouldAttempt::Yes) => {} - // No, this request shouldn't be sent - Ok(ShouldAttempt::No) => { - bail!(before_transmit, "The retry strategy indicates that an initial request shouldn't be made, but it didn't specify why."); - } - // No, we shouldn't make a request because... - Err(err) => bail!(before_transmit, err), - Ok(ShouldAttempt::YesAfterDelay(_)) => { - unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.") - } + ctx.enter_before_transmit_phase(); + halt_on_err!([ctx] => interceptors.read_after_serialization(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.modify_before_retry_loop(ctx, cfg).map_err(Into::into)); + + let retry_strategy = cfg.retry_strategy(); + match retry_strategy.should_attempt_initial_request(cfg) { + // Yes, let's make a request + Ok(ShouldAttempt::Yes) => { /* Keep going */ } + // No, this request shouldn't be sent + Ok(ShouldAttempt::No) => { + let err: Box = "The retry strategy indicates that an initial request shouldn't be made, but it did specify why.".into(); + halt_on_err!([ctx] => Err(err.into())); + } + // No, we shouldn't make a request because... + Err(err) => halt_on_err!([ctx] => Err(err.into())), + Ok(ShouldAttempt::YesAfterDelay(_)) => { + unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.") } } - let mut checkpoint = AttemptCheckpoint::new(before_transmit); - checkpoint = loop { - if !checkpoint.rewind(cfg) { - break checkpoint; - } + loop { let attempt_timeout_config = cfg.maybe_timeout_config(TimeoutKind::OperationAttempt); - - checkpoint = make_an_attempt(checkpoint, cfg, &interceptors) - .maybe_timeout_with_config(attempt_timeout_config) - .await?; - handle_err!([checkpoint] => interceptors.read_after_attempt(checkpoint.after_deser(), cfg)); - handle_err!([checkpoint] => interceptors.modify_before_attempt_completion(checkpoint.after_deser(), cfg)); - + async { + try_attempt(ctx, cfg, interceptors).await; + finally_attempt(ctx, cfg, interceptors).await; + Result::<_, SdkError>::Ok(()) + } + .maybe_timeout_with_config(attempt_timeout_config) + .await + .expect("These are infallible; The retry strategy will decide whether to stop or not."); let retry_strategy = cfg.retry_strategy(); - match retry_strategy.should_attempt_retry(checkpoint.after_deser(), cfg) { + let should_attempt = halt_on_err!([ctx] => retry_strategy.should_attempt_retry(ctx, cfg).map_err(Into::into)); + match should_attempt { // Yes, let's retry the request - Ok(ShouldAttempt::Yes) => continue, + ShouldAttempt::Yes => continue, // No, this request shouldn't be retried - Ok(ShouldAttempt::No) => {} - Ok(ShouldAttempt::YesAfterDelay(_delay)) => { + ShouldAttempt::No => { + break; + } + ShouldAttempt::YesAfterDelay(_delay) => { // TODO(enableNewSmithyRuntime): implement retries with explicit delay todo!("implement retries with an explicit delay.") } - // I couldn't determine if the request should be retried because an error occurred. - Err(err) => bail!([checkpoint], err), } - - break checkpoint; - }; - - handle_err!([checkpoint] => interceptors.modify_before_completion(checkpoint.after_deser(), cfg)); - handle_err!([checkpoint] => interceptors.read_after_execution(checkpoint.after_deser(), cfg)); - - checkpoint.finalize() + } } -// Making an HTTP request can fail for several reasons, but we still need to -// call lifecycle events when that happens. Therefore, we define this -// `make_an_attempt` function to make error handling simpler. -#[tracing::instrument(skip_all)] -async fn make_an_attempt( - mut checkpoint: AttemptCheckpoint, +async fn try_attempt( + ctx: &mut InterceptorContext, cfg: &mut ConfigBag, interceptors: &Interceptors, -) -> Result> { - handle_err!([checkpoint] => interceptors.read_before_attempt(checkpoint.before_transmit(), cfg)); - handle_err!([checkpoint] => orchestrate_endpoint(checkpoint.before_transmit(), cfg)); - handle_err!([checkpoint] => interceptors.modify_before_signing(checkpoint.before_transmit(), cfg)); - handle_err!([checkpoint] => interceptors.read_before_signing(checkpoint.before_transmit(), cfg)); +) { + halt_on_err!([ctx] => interceptors.read_before_attempt(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => orchestrate_endpoint(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.modify_before_signing(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.read_before_signing(ctx, cfg).map_err(Into::into)); - checkpoint = orchestrate_auth(checkpoint, cfg).await?; + halt_on_err!([ctx] => orchestrate_auth(ctx, cfg).await.map_err(Into::into)); - handle_err!([checkpoint] => interceptors.read_after_signing(checkpoint.before_transmit(), cfg)); - handle_err!([checkpoint] => interceptors.modify_before_transmit(checkpoint.before_transmit(), cfg)); - handle_err!([checkpoint] => interceptors.read_before_transmit(checkpoint.before_transmit(), cfg)); + halt_on_err!([ctx] => interceptors.read_after_signing(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.modify_before_transmit(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.read_before_transmit(ctx, cfg).map_err(Into::into)); // The connection consumes the request but we need to keep a copy of it // within the interceptor context, so we clone it here. - checkpoint.transition_to_transmit(); - let call_result = handle_err!([checkpoint] => { - let request = checkpoint.transmit().take_request(); - cfg.connection().call(request).await + ctx.enter_transmit_phase(); + let call_result = halt_on_err!([ctx] => { + let request = ctx.take_request(); + cfg.connection().call(request).await.map_err(Into::into) }); - checkpoint.transmit().set_response(call_result); - checkpoint.transition_to_before_deserialization(); + ctx.set_response(call_result); + ctx.enter_before_deserialization_phase(); - handle_err!([checkpoint] => interceptors.read_after_transmit(checkpoint.before_deser(), cfg)); - handle_err!([checkpoint] => interceptors.modify_before_deserialization(checkpoint.before_deser(), cfg)); - handle_err!([checkpoint] => interceptors.read_before_deserialization(checkpoint.before_deser(), cfg)); + halt_on_err!([ctx] => interceptors.read_after_transmit(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.modify_before_deserialization(ctx, cfg).map_err(Into::into)); + halt_on_err!([ctx] => interceptors.read_before_deserialization(ctx, cfg).map_err(Into::into)); - checkpoint.transition_to_deserialization(); - let output_or_error = handle_err!([checkpoint] => { - let response = checkpoint.deser().response_mut(); + ctx.enter_deserialization_phase(); + let output_or_error = async { + let response = ctx.response_mut(); let response_deserializer = cfg.response_deserializer(); match response_deserializer.deserialize_streaming(response) { Some(output_or_error) => Ok(output_or_error), @@ -207,14 +192,31 @@ async fn make_an_attempt( .await .map(|_| response_deserializer.deserialize_nonstreaming(response)), } - }); + } + .await + .expect("how should I insert this into the context?"); + ctx.set_output_or_error(output_or_error); - checkpoint.deser().set_output_or_error(output_or_error); + ctx.enter_after_deserialization_phase(); + halt_on_err!([ctx] => interceptors.read_after_deserialization(ctx, cfg).map_err(Into::into)); +} - checkpoint.transition_to_after_deserialization(); - handle_err!([checkpoint] => interceptors.read_after_deserialization(checkpoint.after_deser(), cfg)); +async fn finally_attempt( + ctx: &mut InterceptorContext, + cfg: &mut ConfigBag, + interceptors: &Interceptors, +) { + continue_on_err!([ctx] => interceptors.modify_before_attempt_completion(ctx, cfg).map_err(Into::into)); + continue_on_err!([ctx] => interceptors.read_after_attempt(ctx, cfg).map_err(Into::into)); +} - Ok(checkpoint) +async fn finally_op( + ctx: &mut InterceptorContext, + cfg: &mut ConfigBag, + interceptors: &Interceptors, +) { + continue_on_err!([ctx] => interceptors.modify_before_completion(ctx, cfg).map_err(Into::into)); + continue_on_err!([ctx] => interceptors.read_after_execution(ctx, cfg).map_err(Into::into)); } #[cfg(all(test, feature = "test-util", feature = "anonymous-auth"))] @@ -230,14 +232,20 @@ mod tests { serializer::CannedRequestSerializer, }; use aws_smithy_http::body::SdkBody; - use aws_smithy_runtime_api::client::interceptors::context::phase::{ - AfterDeserialization, BeforeDeserialization, BeforeSerialization, BeforeTransmit, + use aws_smithy_runtime_api::client::interceptors::context::wrappers::{ + FinalizerInterceptorContextMut, FinalizerInterceptorContextRef, + }; + use aws_smithy_runtime_api::client::interceptors::context::Output; + use aws_smithy_runtime_api::client::interceptors::{ + AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut, + BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut, + BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, + BeforeTransmitInterceptorContextRef, }; - use aws_smithy_runtime_api::client::interceptors::context::{Error, Output}; use aws_smithy_runtime_api::client::interceptors::{ - Interceptor, InterceptorContext, InterceptorRegistrar, SharedInterceptor, + Interceptor, InterceptorRegistrar, SharedInterceptor, }; - use aws_smithy_runtime_api::client::orchestrator::ConfigBagAccessors; + use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError}; use aws_smithy_runtime_api::client::runtime_plugin::{BoxError, RuntimePlugin, RuntimePlugins}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_runtime_api::type_erasure::TypeErasedBox; @@ -257,7 +265,7 @@ mod tests { http::Response::builder() .status(StatusCode::OK) .body(SdkBody::empty()) - .map_err(|err| Error::new(Box::new(err))) + .map_err(|err| OrchestratorError::other(Box::new(err))) .map(|res| Output::new(Box::new(res))), ) } @@ -311,6 +319,21 @@ mod tests { } } + #[derive(Debug)] + struct FailingInterceptorsClientRuntimePlugin; + + impl RuntimePlugin for FailingInterceptorsClientRuntimePlugin { + fn configure( + &self, + _cfg: &mut ConfigBag, + interceptors: &mut InterceptorRegistrar, + ) -> Result<(), BoxError> { + interceptors.register(SharedInterceptor::new(FailingInterceptorA)); + + Ok(()) + } + } + #[derive(Debug)] struct FailingInterceptorsOperationRuntimePlugin; @@ -320,7 +343,6 @@ mod tests { _cfg: &mut ConfigBag, interceptors: &mut InterceptorRegistrar, ) -> Result<(), BoxError> { - interceptors.register(SharedInterceptor::new(FailingInterceptorA)); interceptors.register(SharedInterceptor::new(FailingInterceptorB)); interceptors.register(SharedInterceptor::new(FailingInterceptorC)); @@ -330,6 +352,7 @@ mod tests { let input = TypeErasedBox::new(Box::new(())); let runtime_plugins = RuntimePlugins::new() + .with_client_plugin(FailingInterceptorsClientRuntimePlugin) .with_operation_plugin(TestOperationRuntimePlugin) .with_operation_plugin(AnonymousAuthRuntimePlugin) .with_operation_plugin(FailingInterceptorsOperationRuntimePlugin); @@ -351,7 +374,7 @@ mod tests { let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ReadBeforeExecution, source: Some(\"FailingInterceptorC\") } })""#.to_string(); interceptor_error_handling_test!( read_before_execution, - &InterceptorContext, + &BeforeSerializationInterceptorContextRef<'_>, expected ); } @@ -362,7 +385,7 @@ mod tests { let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ModifyBeforeSerialization, source: Some(\"FailingInterceptorC\") } })""#.to_string(); interceptor_error_handling_test!( modify_before_serialization, - &mut InterceptorContext, + &mut BeforeSerializationInterceptorContextMut<'_>, expected ); } @@ -373,7 +396,7 @@ mod tests { let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ReadBeforeSerialization, source: Some(\"FailingInterceptorC\") } })""#.to_string(); interceptor_error_handling_test!( read_before_serialization, - &InterceptorContext, + &BeforeSerializationInterceptorContextRef<'_>, expected ); } @@ -384,7 +407,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ReadAfterSerialization, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( read_after_serialization, - &InterceptorContext, + &BeforeTransmitInterceptorContextRef<'_>, expected ); } @@ -395,7 +418,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeRetryLoop, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( modify_before_retry_loop, - &mut InterceptorContext, + &mut BeforeTransmitInterceptorContextMut<'_>, expected ); } @@ -406,7 +429,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ReadBeforeAttempt, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( read_before_attempt, - &InterceptorContext, + &BeforeTransmitInterceptorContextRef<'_>, expected ); } @@ -417,7 +440,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeSigning, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( modify_before_signing, - &mut InterceptorContext, + &mut BeforeTransmitInterceptorContextMut<'_>, expected ); } @@ -428,7 +451,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ReadBeforeSigning, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( read_before_signing, - &InterceptorContext, + &BeforeTransmitInterceptorContextRef<'_>, expected ); } @@ -439,7 +462,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ReadAfterSigning, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( read_after_signing, - &InterceptorContext, + &BeforeTransmitInterceptorContextRef<'_>, expected ); } @@ -450,7 +473,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeTransmit, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( modify_before_transmit, - &mut InterceptorContext, + &mut BeforeTransmitInterceptorContextMut<'_>, expected ); } @@ -461,7 +484,7 @@ mod tests { let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ReadBeforeTransmit, source: Some(\"FailingInterceptorC\") }, connection: Unknown } })""#.to_string(); interceptor_error_handling_test!( read_before_transmit, - &InterceptorContext, + &BeforeTransmitInterceptorContextRef<'_>, expected ); } @@ -472,7 +495,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterTransmit, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( read_after_transmit, - &InterceptorContext, + &BeforeDeserializationInterceptorContextRef<'_>, expected ); } @@ -483,7 +506,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeDeserialization, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( modify_before_deserialization, - &mut InterceptorContext, + &mut BeforeDeserializationInterceptorContextMut<'_>, expected ); } @@ -494,7 +517,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadBeforeDeserialization, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( read_before_deserialization, - &InterceptorContext, + &BeforeDeserializationInterceptorContextRef<'_>, expected ); } @@ -505,7 +528,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterDeserialization, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( read_after_deserialization, - &InterceptorContext, + &AfterDeserializationInterceptorContextRef<'_>, expected ); } @@ -516,7 +539,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( modify_before_attempt_completion, - &mut InterceptorContext, + &mut FinalizerInterceptorContextMut<'_>, expected ); } @@ -527,7 +550,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterAttempt, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( read_after_attempt, - &InterceptorContext, + &FinalizerInterceptorContextRef<'_>, expected ); } @@ -538,7 +561,7 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( modify_before_completion, - &mut InterceptorContext, + &mut FinalizerInterceptorContextMut<'_>, expected ); } @@ -549,7 +572,297 @@ mod tests { let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterExecution, source: Some(\"FailingInterceptorC\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); interceptor_error_handling_test!( read_after_execution, - &InterceptorContext, + &FinalizerInterceptorContextRef<'_>, + expected + ); + } + + macro_rules! interceptor_error_redirection_test { + ($origin_interceptor:ident, $origin_ctx:ty, $destination_interceptor:ident, $destination_ctx:ty, $expected:expr) => { + #[derive(Debug)] + struct OriginInterceptor; + impl Interceptor for OriginInterceptor { + fn $origin_interceptor( + &self, + _ctx: $origin_ctx, + _cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + tracing::debug!("OriginInterceptor called!"); + Err("OriginInterceptor".into()) + } + } + + #[derive(Debug)] + struct DestinationInterceptor; + impl Interceptor for DestinationInterceptor { + fn $destination_interceptor( + &self, + _ctx: $destination_ctx, + _cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + tracing::debug!("DestinationInterceptor called!"); + Err("DestinationInterceptor".into()) + } + } + + #[derive(Debug)] + struct InterceptorsTestOperationRuntimePlugin; + + impl RuntimePlugin for InterceptorsTestOperationRuntimePlugin { + fn configure( + &self, + _cfg: &mut ConfigBag, + interceptors: &mut InterceptorRegistrar, + ) -> Result<(), BoxError> { + interceptors.register(SharedInterceptor::new(OriginInterceptor)); + interceptors.register(SharedInterceptor::new(DestinationInterceptor)); + + Ok(()) + } + } + + let input = TypeErasedBox::new(Box::new(())); + let runtime_plugins = RuntimePlugins::new() + .with_operation_plugin(TestOperationRuntimePlugin) + .with_operation_plugin(AnonymousAuthRuntimePlugin) + .with_operation_plugin(InterceptorsTestOperationRuntimePlugin); + let actual = invoke(input, &runtime_plugins) + .await + .expect_err("should error"); + let actual = format!("{:?}", actual); + assert_eq!($expected, format!("{:?}", actual)); + + assert!(logs_contain("OriginInterceptor called!")); + assert!(logs_contain("DestinationInterceptor called!")); + }; + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_execution_error_causes_jump_to_modify_before_completion() { + let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"DestinationInterceptor\") } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_execution, + &BeforeSerializationInterceptorContextRef<'_>, + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_serialization_error_causes_jump_to_modify_before_completion() { + let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"DestinationInterceptor\") } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_serialization, + &mut BeforeSerializationInterceptorContextMut<'_>, + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_serialization_error_causes_jump_to_modify_before_completion() { + let expected = r#""ConstructionFailure(ConstructionFailure { source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"DestinationInterceptor\") } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_serialization, + &BeforeSerializationInterceptorContextRef<'_>, + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_after_serialization_error_causes_jump_to_modify_before_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + read_after_serialization, + &BeforeTransmitInterceptorContextRef<'_>, + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_retry_loop_error_causes_jump_to_modify_before_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_retry_loop, + &mut BeforeTransmitInterceptorContextMut<'_>, + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_attempt_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_attempt, + &BeforeTransmitInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_signing_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_signing, + &mut BeforeTransmitInterceptorContextMut<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_signing_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_signing, + &BeforeTransmitInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_after_signing_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + read_after_signing, + &BeforeTransmitInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_transmit_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_transmit, + &mut BeforeTransmitInterceptorContextMut<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_transmit_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""DispatchFailure(DispatchFailure { source: ConnectorError { kind: Other(None), source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, connection: Unknown } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_transmit, + &BeforeTransmitInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_after_transmit_error_causes_jump_to_modify_before_attempt_completion() { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + read_after_transmit, + &BeforeDeserializationInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_deserialization_error_causes_jump_to_modify_before_attempt_completion( + ) { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_deserialization, + &mut BeforeDeserializationInterceptorContextMut<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_before_deserialization_error_causes_jump_to_modify_before_attempt_completion( + ) { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(None), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + read_before_deserialization, + &BeforeDeserializationInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_read_after_deserialization_error_causes_jump_to_modify_before_attempt_completion() + { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ModifyBeforeAttemptCompletion, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + read_after_deserialization, + &AfterDeserializationInterceptorContextRef<'_>, + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + expected + ); + } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_attempt_completion_error_causes_jump_to_read_after_attempt() { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterAttempt, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_attempt_completion, + &mut FinalizerInterceptorContextMut<'_>, + read_after_attempt, + &FinalizerInterceptorContextRef<'_>, + expected + ); + } + + // #[tokio::test] + // #[traced_test] + // async fn test_read_after_attempt_error_causes_jump_to_modify_before_attempt_completion() { + // todo!("I'm confused by the behavior described in the spec") + // } + + #[tokio::test] + #[traced_test] + async fn test_modify_before_completion_error_causes_jump_to_read_after_execution() { + let expected = r#""ResponseError(ResponseError { source: InterceptorError { kind: ReadAfterExecution, source: Some(\"DestinationInterceptor\") }, raw: Response { status: 200, version: HTTP/1.1, headers: {}, body: SdkBody { inner: Once(Some(b\"\")), retryable: true } } })""#.to_string(); + interceptor_error_redirection_test!( + modify_before_completion, + &mut FinalizerInterceptorContextMut<'_>, + read_after_execution, + &FinalizerInterceptorContextRef<'_>, expected ); } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs index a3b053837..883e628ad 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs @@ -3,25 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::{bail, handle_err}; -use aws_smithy_http::result::SdkError; -use aws_smithy_runtime_api::client::interceptors::context::{AttemptCheckpoint, Error}; -use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpResponse}; +use aws_smithy_runtime_api::client::interceptors::InterceptorContext; +use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors}; use aws_smithy_runtime_api::config_bag::ConfigBag; +use std::fmt; pub(super) async fn orchestrate_auth( - mut checkpoint: AttemptCheckpoint, + ctx: &mut InterceptorContext, cfg: &ConfigBag, -) -> Result> { - fn construction_failure(err: impl Into) -> SdkError { - SdkError::construction_failure(err) - } - +) -> Result<(), BoxError> { let params = cfg.auth_option_resolver_params(); - let auth_options = cfg - .auth_option_resolver() - .resolve_auth_options(params) - .map_err(construction_failure)?; + let auth_options = cfg.auth_option_resolver().resolve_auth_options(params)?; let identity_resolvers = cfg.identity_resolvers(); tracing::trace!( @@ -35,23 +27,31 @@ pub(super) async fn orchestrate_auth( if let Some(identity_resolver) = auth_scheme.identity_resolver(identity_resolvers) { let request_signer = auth_scheme.request_signer(); - let identity = identity_resolver - .resolve_identity(cfg) - .await - .map_err(construction_failure)?; - let request = checkpoint.before_transmit().request_mut(); - handle_err!([checkpoint] => request_signer.sign_request(request, &identity, cfg)); - return Ok(checkpoint); + let identity = identity_resolver.resolve_identity(cfg).await?; + let request = ctx.request_mut(); + request_signer.sign_request(request, &identity, cfg)?; + return Ok(()); } } } - bail!( - [checkpoint], - "no auth scheme matched auth options. This is a bug. Please file an issue." - ); + Err(NoMatchingAuthScheme.into()) +} + +#[derive(Debug)] +struct NoMatchingAuthScheme; + +impl fmt::Display for NoMatchingAuthScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "no auth scheme matched auth options. This is a bug. Please file an issue." + ) + } } +impl std::error::Error for NoMatchingAuthScheme {} + #[cfg(test)] mod tests { use super::*; @@ -115,11 +115,11 @@ mod tests { } } - let input = TypedBox::new("doesnt-matter").erase(); - let mut context = InterceptorContext::<()>::new(input).into_serialization_phase(); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); - let _ = context.take_input(); - let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); + let mut ctx = InterceptorContext::new(TypedBox::new("doesnt-matter").erase()); + ctx.enter_serialization_phase(); + ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = ctx.take_input(); + ctx.enter_before_transmit_phase(); let mut cfg = ConfigBag::base(); cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter")); @@ -135,16 +135,11 @@ mod tests { .build(), ); - let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); + orchestrate_auth(&mut ctx, &cfg).await.expect("success"); assert_eq!( "success!", - checkpoint - .before_transmit() - .request() - .headers() - .get("Authorization") - .unwrap() + ctx.request().headers().get("Authorization").unwrap() ); } @@ -157,11 +152,11 @@ mod tests { }; use aws_smithy_runtime_api::client::identity::http::{Login, Token}; - let mut context = InterceptorContext::<()>::new(TypedBox::new("doesnt-matter").erase()) - .into_serialization_phase(); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); - let _ = context.take_input(); - let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); + let mut ctx = InterceptorContext::new(TypedBox::new("doesnt-matter").erase()); + ctx.enter_serialization_phase(); + ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = ctx.take_input(); + ctx.enter_before_transmit_phase(); let mut cfg = ConfigBag::base(); cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter")); @@ -183,17 +178,11 @@ mod tests { .build(), ); - let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); - + orchestrate_auth(&mut ctx, &cfg).await.expect("success"); assert_eq!( // "YTpi" == "a:b" in base64 "Basic YTpi", - checkpoint - .before_transmit() - .request() - .headers() - .get("Authorization") - .unwrap() + ctx.request().headers().get("Authorization").unwrap() ); // Next, test the presence of a bearer token and absence of basic auth @@ -203,21 +192,15 @@ mod tests { .build(), ); - let mut context = InterceptorContext::<()>::new(TypedBox::new("doesnt-matter").erase()) - .into_serialization_phase(); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); - let _ = context.take_input(); - let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); - - let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); + let mut ctx = InterceptorContext::new(TypedBox::new("doesnt-matter").erase()); + ctx.enter_serialization_phase(); + ctx.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = ctx.take_input(); + ctx.enter_before_transmit_phase(); + orchestrate_auth(&mut ctx, &cfg).await.expect("success"); assert_eq!( "Bearer t", - checkpoint - .before_transmit() - .request() - .headers() - .get("Authorization") - .unwrap() + ctx.request().headers().get("Authorization").unwrap() ); } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs index 8761bc5d6..cb4e4e9c0 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs @@ -7,7 +7,6 @@ use aws_smithy_http::endpoint::error::ResolveEndpointError; use aws_smithy_http::endpoint::{ apply_endpoint, EndpointPrefix, ResolveEndpoint, SharedEndpointResolver, }; -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::{ BoxError, ConfigBagAccessors, EndpointResolver, EndpointResolverParams, HttpRequest, @@ -130,7 +129,7 @@ where } pub(super) fn orchestrate_endpoint( - ctx: &mut InterceptorContext, + ctx: &mut InterceptorContext, cfg: &ConfigBag, ) -> Result<(), BoxError> { let params = cfg.endpoint_resolver_params(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/request_attempts.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/request_attempts.rs index c69137b02..6d48abee6 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/request_attempts.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/request_attempts.rs @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeTransmitInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; #[derive(Debug, Clone, Default)] @@ -48,7 +49,7 @@ impl RequestAttemptsInterceptor { impl Interceptor for RequestAttemptsInterceptor { fn modify_before_retry_loop( &self, - _ctx: &mut InterceptorContext, + _ctx: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { cfg.put(RequestAttempts::new()); @@ -57,7 +58,7 @@ impl Interceptor for RequestAttemptsInterceptor { fn modify_before_transmit( &self, - _ctx: &mut InterceptorContext, + _ctx: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { if let Some(request_attempts) = cfg.get::().cloned() { diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/service_clock_skew.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/service_clock_skew.rs index 579ecbb48..eb988c857 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/service_clock_skew.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/interceptors/service_clock_skew.rs @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeDeserialization; -use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; +use aws_smithy_runtime_api::client::interceptors::{ + BeforeDeserializationInterceptorContextMut, BoxError, Interceptor, +}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_types::date_time::Format; use aws_smithy_types::DateTime; @@ -48,7 +49,7 @@ fn calculate_skew(time_sent: DateTime, time_received: DateTime) -> Duration { } fn extract_time_sent_from_response( - ctx: &mut InterceptorContext, + ctx: &mut BeforeDeserializationInterceptorContextMut<'_>, ) -> Result { let date_header = ctx .response() @@ -62,7 +63,7 @@ fn extract_time_sent_from_response( impl Interceptor for ServiceClockSkewInterceptor { fn modify_before_deserialization( &self, - ctx: &mut InterceptorContext, + ctx: &mut BeforeDeserializationInterceptorContextMut<'_>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let time_received = DateTime::from(SystemTime::now()); diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs index 1ee985c22..d3f4abe56 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs @@ -4,7 +4,6 @@ */ use crate::client::orchestrator::interceptors::RequestAttempts; -use aws_smithy_runtime_api::client::interceptors::context::phase::AfterDeserialization; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::BoxError; use aws_smithy_runtime_api::client::retries::{ @@ -41,7 +40,7 @@ impl RetryStrategy for FixedDelayRetryStrategy { fn should_attempt_retry( &self, - ctx: &InterceptorContext, + ctx: &InterceptorContext, cfg: &ConfigBag, ) -> Result { // Look a the result. If it's OK then we're done; No retry required. Otherwise, we need to inspect it diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs index 3b1b56ffb..f415af369 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_runtime_api::client::interceptors::context::phase::AfterDeserialization; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::BoxError; use aws_smithy_runtime_api::client::retries::{RetryStrategy, ShouldAttempt}; @@ -25,7 +24,7 @@ impl RetryStrategy for NeverRetryStrategy { fn should_attempt_retry( &self, - _context: &InterceptorContext, + _context: &InterceptorContext, _cfg: &ConfigBag, ) -> Result { Ok(ShouldAttempt::No) diff --git a/rust-runtime/aws-smithy-runtime/src/client/test_util/deserializer.rs b/rust-runtime/aws-smithy-runtime/src/client/test_util/deserializer.rs index 3c7bc1610..7df0e2fe7 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/test_util/deserializer.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/test_util/deserializer.rs @@ -6,7 +6,7 @@ use aws_smithy_runtime_api::client::interceptors::context::{Error, Output}; use aws_smithy_runtime_api::client::interceptors::InterceptorRegistrar; use aws_smithy_runtime_api::client::orchestrator::{ - ConfigBagAccessors, HttpResponse, ResponseDeserializer, + ConfigBagAccessors, HttpResponse, OrchestratorError, ResponseDeserializer, }; use aws_smithy_runtime_api::client::runtime_plugin::{BoxError, RuntimePlugin}; use aws_smithy_runtime_api::config_bag::ConfigBag; @@ -14,17 +14,17 @@ use std::sync::Mutex; #[derive(Default, Debug)] pub struct CannedResponseDeserializer { - inner: Mutex>>, + inner: Mutex>>>, } impl CannedResponseDeserializer { - pub fn new(output: Result) -> Self { + pub fn new(output: Result>) -> Self { Self { inner: Mutex::new(Some(output)), } } - pub fn take(&self) -> Option> { + pub fn take(&self) -> Option>> { match self.inner.lock() { Ok(mut guard) => guard.take(), Err(_) => None, @@ -33,7 +33,10 @@ impl CannedResponseDeserializer { } impl ResponseDeserializer for CannedResponseDeserializer { - fn deserialize_nonstreaming(&self, _response: &HttpResponse) -> Result { + fn deserialize_nonstreaming( + &self, + _response: &HttpResponse, + ) -> Result> { self.take() .ok_or("CannedResponseDeserializer's inner value has already been taken.") .unwrap() -- GitLab