diff --git a/aws/rust-runtime/aws-runtime/src/invocation_id.rs b/aws/rust-runtime/aws-runtime/src/invocation_id.rs index 89bcbb8d5bc3d9c5902d4a74c39c0d415aed07df..98d5405ecbdbf4e99f625215001e4c7aae3f1a45 100644 --- a/aws/rust-runtime/aws-runtime/src/invocation_id.rs +++ b/aws/rust-runtime/aws-runtime/src/invocation_id.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::error::BoxError; use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; @@ -37,10 +38,10 @@ impl Default for InvocationIdInterceptor { impl Interceptor for InvocationIdInterceptor { fn modify_before_retry_loop( &self, - context: &mut InterceptorContext, + context: &mut InterceptorContext, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - let headers = context.request_mut()?.headers_mut(); + let headers = context.request_mut().headers_mut(); let id = _cfg.get::().unwrap_or(&self.id); headers.append(AMZ_SDK_INVOCATION_ID, id.0.clone()); Ok(()) @@ -72,24 +73,26 @@ impl InvocationId { mod tests { use crate::invocation_id::InvocationIdInterceptor; use aws_smithy_http::body::SdkBody; + use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_smithy_runtime_api::type_erasure::TypedBox; use http::HeaderValue; - fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a HeaderValue { - context - .request() - .unwrap() - .headers() - .get(header_name) - .unwrap() + fn expect_header<'a>( + context: &'a InterceptorContext, + header_name: &str, + ) -> &'a HeaderValue { + context.request().headers().get(header_name).unwrap() } #[test] fn test_id_is_generated_and_set() { - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) + .into_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = context.take_input(); + let mut context = context.into_before_transmit_phase(); let mut config = ConfigBag::base(); let interceptor = InvocationIdInterceptor::new(); diff --git a/aws/rust-runtime/aws-runtime/src/recursion_detection.rs b/aws/rust-runtime/aws-runtime/src/recursion_detection.rs index deafdd973f7bfb5893c25b5a9c3bddff628f7c63..38d2cf2725026e0b31051ef3b2ee2f4b2c3e40fc 100644 --- a/aws/rust-runtime/aws-runtime/src/recursion_detection.rs +++ b/aws/rust-runtime/aws-runtime/src/recursion_detection.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; use aws_types::os_shim_internal::Env; @@ -39,10 +40,10 @@ impl RecursionDetectionInterceptor { impl Interceptor for RecursionDetectionInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut InterceptorContext, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - let request = context.request_mut()?; + let request = context.request_mut(); if request.headers().contains_key(TRACE_ID_HEADER) { return Ok(()); } @@ -145,14 +146,17 @@ mod tests { request = request.header(name, value); } let request = request.body(SdkBody::empty()).expect("must be valid"); - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); + let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) + .into_serialization_phase(); context.set_request(request); + let _ = context.take_input(); + let mut context = context.into_before_transmit_phase(); let mut config = ConfigBag::base(); RecursionDetectionInterceptor { env } .modify_before_signing(&mut context, &mut config) .expect("interceptor must succeed"); - let mutated_request = context.request().expect("request is still set"); + let mutated_request = context.request(); for name in mutated_request.headers().keys() { assert_eq!( mutated_request.headers().get_all(name).iter().count(), diff --git a/aws/rust-runtime/aws-runtime/src/user_agent.rs b/aws/rust-runtime/aws-runtime/src/user_agent.rs index a6c6ff1e8507938d05af4ae0d9f580e026f08d81..cdd3003da63bf49ceabf76d6b2a8edbf52116f86 100644 --- a/aws/rust-runtime/aws-runtime/src/user_agent.rs +++ b/aws/rust-runtime/aws-runtime/src/user_agent.rs @@ -4,6 +4,7 @@ */ use aws_http::user_agent::{ApiMetadata, AwsUserAgent}; +use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::error::BoxError; use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext}; use aws_smithy_runtime_api::config_bag::ConfigBag; @@ -72,7 +73,7 @@ fn header_values( impl Interceptor for UserAgentInterceptor { fn modify_before_signing( &self, - context: &mut InterceptorContext, + context: &mut InterceptorContext, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let api_metadata = cfg @@ -95,7 +96,7 @@ impl Interceptor for UserAgentInterceptor { Cow::Owned(ua) }); - let headers = context.request_mut()?.headers_mut(); + let headers = context.request_mut().headers_mut(); let (user_agent, x_amz_user_agent) = header_values(&ua)?; headers.append(USER_AGENT, user_agent); headers.append(X_AMZ_USER_AGENT, x_amz_user_agent); @@ -112,10 +113,12 @@ mod tests { use aws_smithy_runtime_api::type_erasure::TypedBox; use aws_smithy_types::error::display::DisplayErrorContext; - fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str { + fn expect_header<'a>( + context: &'a InterceptorContext, + header_name: &str, + ) -> &'a str { context .request() - .unwrap() .headers() .get(header_name) .unwrap() @@ -123,10 +126,17 @@ mod tests { .unwrap() } + fn context() -> InterceptorContext { + let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase()) + .into_serialization_phase(); + context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = context.take_input(); + context.into_before_transmit_phase() + } + #[test] fn test_overridden_ua() { - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let mut context = context(); let mut config = ConfigBag::base(); config.put(AwsUserAgent::for_tests()); @@ -149,8 +159,7 @@ mod tests { #[test] fn test_default_ua() { - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let mut context = context(); let api_metadata = ApiMetadata::new("some-service", "some-version"); let mut config = ConfigBag::base(); @@ -178,8 +187,7 @@ mod tests { #[test] fn test_app_name() { - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let mut context = context(); let api_metadata = ApiMetadata::new("some-service", "some-version"); let mut config = ConfigBag::base(); @@ -207,9 +215,7 @@ mod tests { #[test] fn test_api_metadata_missing() { - let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase()); - context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); - + let mut context = context(); let mut config = ConfigBag::base(); let interceptor = UserAgentInterceptor::new(); diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt index ef730207a73085607434fce63fabf05a7d30064e..f3f1da9720d24f28ed2797b18289bbe2207d2f70 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RetryClassifierDecorator.kt @@ -74,7 +74,7 @@ class OperationRetryClassifiersFeature( "RetryClassifiers" to smithyRuntimeApi.resolve("client::retries::RetryClassifiers"), "OperationError" to codegenContext.symbolProvider.symbolForOperationError(operation), "SdkError" to RuntimeType.smithyHttp(runtimeConfig).resolve("result::SdkError"), - "ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedBox"), + "ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedError"), ) override fun section(section: OperationRuntimePluginSection) = when (section) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt index 413785734677e950f1ee04ddbe82fde40c539648..ee1975d65d33a9aef3ddeea93299828da44c1f7f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsInterceptorGenerator.kt @@ -49,6 +49,7 @@ class EndpointParamsInterceptorGenerator( "HttpResponse" to orchestrator.resolve("HttpResponse"), "Interceptor" to interceptors.resolve("Interceptor"), "InterceptorContext" to interceptors.resolve("InterceptorContext"), + "BeforeSerializationPhase" to interceptors.resolve("context::phase::BeforeSerialization"), "InterceptorError" to interceptors.resolve("error::InterceptorError"), "Params" to endpointTypesGenerator.paramsStruct(), ) @@ -66,11 +67,10 @@ class EndpointParamsInterceptorGenerator( impl #{Interceptor} for $interceptorName { fn read_before_execution( &self, - context: &#{InterceptorContext}, + context: &#{InterceptorContext}<#{BeforeSerializationPhase}>, cfg: &mut #{ConfigBag}, ) -> Result<(), #{BoxError}> { - let _input = context.input()?; - let _input = _input + let _input = context.input() .downcast_ref::<${operationInput.name}>() .ok_or("failed to downcast to ${operationInput.name}")?; diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt index f35fe7446c1b5f253273210dfca722e08c74fa26..7714850f7107c03711945c4b52d725a35498dff9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -84,7 +84,6 @@ class ServiceRuntimePluginGenerator( "Interceptors" to runtimeApi.resolve("client::interceptors::Interceptors"), "SharedEndpointResolver" to http.resolve("endpoint::SharedEndpointResolver"), "StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"), - "TraceProbe" to runtimeApi.resolve("client::orchestrator::TraceProbe"), ) } @@ -131,18 +130,6 @@ class ServiceRuntimePluginGenerator( .expect("connection set"); cfg.set_connection(connection); - // TODO(RuntimePlugins): Add the TraceProbe to the config bag - cfg.set_trace_probe({ - ##[derive(Debug)] - struct StubTraceProbe; - impl #{TraceProbe} for StubTraceProbe { - fn dispatch_events(&self) { - // no-op - } - } - StubTraceProbe - }); - #{additional_config} Ok(()) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index f10de15ee457c04a3defe6663a0538d1310c07c3..c4714038f02fe8b77b41532d514df85420d2021d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -388,7 +388,7 @@ class FluentClientGenerator( .await .map_err(|err| { err.map_service_error(|err| { - #{TypedBox}::<#{OperationError}>::assume_from(err) + #{TypedBox}::<#{OperationError}>::assume_from(err.into()) .expect("correct error type") .unwrap() }) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt index 4b355cf8badd9f5154e952917313ff4c966c778f..654463df74a8f3ca3fd006c153310589210d71fe 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt @@ -157,10 +157,10 @@ class ResponseDeserializerGenerator( pub(crate) fn $fnName(result: Result) -> Result<#{Output}, #{Error}> where O: std::fmt::Debug + Send + Sync + 'static, - E: std::fmt::Debug + Send + Sync + 'static, + E: std::error::Error + std::fmt::Debug + Send + Sync + 'static, { result.map(|output| #{TypedBox}::new(output).erase()) - .map_err(|error| #{TypedBox}::new(error).erase()) + .map_err(|error| #{TypedBox}::new(error).erase_error()) } """, *codegenScope, diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs index 20a34f6a765ebfb153b86ac0600b07606e0c0ee1..f025cb5a70913a18a20b3f351a4149fd98bf88b1 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs @@ -6,6 +6,9 @@ pub mod context; pub mod error; +use crate::client::interceptors::context::phase::{ + AfterDeserialization, BeforeDeserialization, BeforeSerialization, BeforeTransmit, +}; use crate::config_bag::ConfigBag; use aws_smithy_types::error::display::DisplayErrorContext; pub use context::InterceptorContext; @@ -13,19 +16,23 @@ pub use error::{BoxError, InterceptorError}; use std::sync::Arc; macro_rules! interceptor_trait_fn { - ($name:ident, $docs:tt) => { + ($name:ident, $phase:ident, $docs:tt) => { #[doc = $docs] - fn $name(&self, context: &InterceptorContext, cfg: &mut ConfigBag) -> Result<(), BoxError> { + fn $name( + &self, + context: &InterceptorContext<$phase>, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { let _ctx = context; let _cfg = cfg; Ok(()) } }; - (mut $name:ident, $docs:tt) => { + (mut $name:ident, $phase:ident, $docs:tt) => { #[doc = $docs] fn $name( &self, - context: &mut InterceptorContext, + context: &mut InterceptorContext<$phase>, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { let _ctx = context; @@ -48,6 +55,7 @@ macro_rules! interceptor_trait_fn { pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_execution, + BeforeSerialization, " A hook called at the start of an execution, before the SDK does anything else. @@ -71,6 +79,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_serialization, + BeforeSerialization, " A hook called before the input message is marshalled into a transport message. @@ -98,6 +107,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_serialization, + BeforeSerialization, " A hook called before the input message is marshalled into a transport @@ -119,6 +129,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_serialization, + BeforeTransmit, " /// A hook called after the input message is marshalled into /// a transport message. @@ -140,6 +151,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_retry_loop, + BeforeTransmit, " A hook called before the retry loop is entered. This method has the ability to modify and return a new transport request @@ -161,6 +173,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_attempt, + BeforeTransmit, " A hook called before each attempt at sending the transmission request message to the service. @@ -187,6 +200,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_signing, + BeforeTransmit, " A hook called before the transport request message is signed. This method has the ability to modify and return a new transport @@ -218,6 +232,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_signing, + BeforeTransmit, " A hook called before the transport request message is signed. @@ -241,6 +256,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_signing, + BeforeTransmit, " A hook called after the transport request message is signed. @@ -264,6 +280,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_transmit, + BeforeTransmit, " /// A hook called before the transport request message is sent to the /// service. This method has the ability to modify and return @@ -295,6 +312,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_transmit, + BeforeTransmit, " A hook called before the transport request message is sent to the service. @@ -322,6 +340,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_transmit, + BeforeDeserialization, " A hook called after the transport request message is sent to the service and a transport response message is received. @@ -349,6 +368,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_deserialization, + BeforeDeserialization, " A hook called before the transport response message is unmarshalled. This method has the ability to modify and return a new transport @@ -380,6 +400,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_before_deserialization, + BeforeDeserialization, " A hook called before the transport response message is unmarshalled @@ -406,6 +427,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_deserialization, + AfterDeserialization, " A hook called after the transport response message is unmarshalled. @@ -432,6 +454,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_attempt_completion, + AfterDeserialization, " A hook called when an attempt is completed. This method has the ability to modify and return a new output message or error @@ -460,6 +483,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_attempt, + AfterDeserialization, " A hook called when an attempt is completed. @@ -488,6 +512,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( mut modify_before_completion, + AfterDeserialization, " A hook called when an execution is completed. This method has the ability to modify and return a new @@ -514,6 +539,7 @@ pub trait Interceptor: std::fmt::Debug { interceptor_trait_fn!( read_after_execution, + AfterDeserialization, " A hook called when an execution is completed. @@ -546,17 +572,23 @@ pub struct Interceptors { } macro_rules! interceptor_impl_fn { - (context, $name:ident) => { - interceptor_impl_fn!(context, $name, $name); + (context, $name:ident, $phase:ident) => { + interceptor_impl_fn!(context, $name, $name, $phase); }; - (mut context, $name:ident) => { - interceptor_impl_fn!(mut context, $name, $name); + (mut context, $name:ident, $phase:ident) => { + interceptor_impl_fn!(mut context, $name, $name, $phase); }; - (context, $outer_name:ident, $inner_name:ident) => { - interceptor_impl_fn!($outer_name, $inner_name(context: &InterceptorContext)); + (context, $outer_name:ident, $inner_name:ident, $phase:ident) => { + interceptor_impl_fn!( + $outer_name, + $inner_name(context: &InterceptorContext<$phase>) + ); }; - (mut context, $outer_name:ident, $inner_name:ident) => { - interceptor_impl_fn!($outer_name, $inner_name(context: &mut InterceptorContext)); + (mut context, $outer_name:ident, $inner_name:ident, $phase:ident) => { + interceptor_impl_fn!( + $outer_name, + $inner_name(context: &mut InterceptorContext<$phase>) + ); }; ($outer_name:ident, $inner_name:ident ($context:ident : $context_ty:ty)) => { pub fn $outer_name( @@ -601,28 +633,46 @@ impl Interceptors { self } - interceptor_impl_fn!(context, client_read_before_execution, read_before_execution); + interceptor_impl_fn!( + context, + client_read_before_execution, + read_before_execution, + BeforeSerialization + ); interceptor_impl_fn!( context, operation_read_before_execution, - read_before_execution + read_before_execution, + BeforeSerialization + ); + interceptor_impl_fn!( + mut context, + modify_before_serialization, + BeforeSerialization + ); + interceptor_impl_fn!(context, read_before_serialization, BeforeSerialization); + interceptor_impl_fn!(context, read_after_serialization, BeforeTransmit); + interceptor_impl_fn!(mut context, modify_before_retry_loop, BeforeTransmit); + interceptor_impl_fn!(context, read_before_attempt, BeforeTransmit); + interceptor_impl_fn!(mut context, modify_before_signing, BeforeTransmit); + interceptor_impl_fn!(context, read_before_signing, BeforeTransmit); + interceptor_impl_fn!(context, read_after_signing, BeforeTransmit); + interceptor_impl_fn!(mut context, modify_before_transmit, BeforeTransmit); + interceptor_impl_fn!(context, read_before_transmit, BeforeTransmit); + interceptor_impl_fn!(context, read_after_transmit, BeforeDeserialization); + interceptor_impl_fn!( + mut context, + modify_before_deserialization, + BeforeDeserialization + ); + interceptor_impl_fn!(context, read_before_deserialization, BeforeDeserialization); + interceptor_impl_fn!(context, read_after_deserialization, AfterDeserialization); + interceptor_impl_fn!( + mut context, + modify_before_attempt_completion, + AfterDeserialization ); - interceptor_impl_fn!(mut context, modify_before_serialization); - interceptor_impl_fn!(context, read_before_serialization); - interceptor_impl_fn!(context, read_after_serialization); - interceptor_impl_fn!(mut context, modify_before_retry_loop); - interceptor_impl_fn!(context, read_before_attempt); - interceptor_impl_fn!(mut context, modify_before_signing); - interceptor_impl_fn!(context, read_before_signing); - interceptor_impl_fn!(context, read_after_signing); - interceptor_impl_fn!(mut context, modify_before_transmit); - interceptor_impl_fn!(context, read_before_transmit); - interceptor_impl_fn!(context, read_after_transmit); - interceptor_impl_fn!(mut context, modify_before_deserialization); - interceptor_impl_fn!(context, read_before_deserialization); - interceptor_impl_fn!(context, read_after_deserialization); - interceptor_impl_fn!(mut context, modify_before_attempt_completion); - interceptor_impl_fn!(context, read_after_attempt); - interceptor_impl_fn!(mut context, modify_before_completion); - interceptor_impl_fn!(context, read_after_execution); + interceptor_impl_fn!(context, read_after_attempt, AfterDeserialization); + interceptor_impl_fn!(mut context, modify_before_completion, AfterDeserialization); + interceptor_impl_fn!(context, read_after_execution, AfterDeserialization); } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs index 61d7c7fd5d1ec1666fdff2b117197ce17f6c6543..309b1ca9d9414c2aefad341f5e97101a7d90435a 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs @@ -3,158 +3,802 @@ * SPDX-License-Identifier: Apache-2.0 */ -use super::InterceptorError; +//! Interceptor context. +//! +//! Interceptors have access to varying pieces of context during the course of an operation. +//! +//! An operation is composed of multiple phases. The initial phase is [`phase::BeforeSerialization`], which +//! has the original input as context. The next phase is [`phase::BeforeTransmit`], which has the serialized +//! request as context. Depending on which hook is being called with the dispatch context, +//! the serialized request may or may not be signed (which should be apparent from the hook name). +//! Following the [`phase::BeforeTransmit`] phase is the [`phase::BeforeDeserialization`] phase, which has +//! the raw response available as context. Finally, the [`phase::AfterDeserialization`] phase +//! has both the raw and parsed response available. +//! +//! To summarize: +//! 1. [`phase::BeforeSerialization`]: Only has the operation input. +//! 2. [`phase::BeforeTransmit`]: Only has the serialized request. +//! 3. [`phase::BeforeDeserialization`]: Has the raw response. +//! 3. [`phase::AfterDeserialization`]: Has the raw response and the parsed response. +//! +//! When implementing hooks, if information from a previous phase is required, then implement +//! an earlier hook to examine that context, and save off any necessary information into the +//! [`crate::config_bag::ConfigBag`] for later hooks to examine. Interior mutability is **NOT** +//! recommended for storing request-specific information in your interceptor implementation. +//! Use the [`crate::config_bag::ConfigBag`] instead. + +use crate::client::interceptors::BoxError; use crate::client::orchestrator::{HttpRequest, HttpResponse}; -use crate::type_erasure::TypeErasedBox; +use crate::config_bag::ConfigBag; +use crate::type_erasure::{TypeErasedBox, TypeErasedError}; +use aws_smithy_http::result::SdkError; pub type Input = TypeErasedBox; pub type Output = TypeErasedBox; -pub type Error = TypeErasedBox; +pub type Error = TypeErasedError; pub type OutputOrError = Result; type Request = HttpRequest; type Response = HttpResponse; +/// Operation phases. +pub mod phase { + use crate::client::interceptors::context::{Error, Output}; + use crate::client::interceptors::BoxError; + use crate::client::orchestrator::HttpResponse; + use aws_smithy_http::result::{ConnectorError, SdkError}; + + macro_rules! impl_phase { + ($phase:ty, $convert_err:ident) => { + impl Phase for $phase { + fn convert_error( + &self, + error: BoxError, + output_or_error: Option>, + response: Option, + ) -> SdkError { + $convert_err(error, output_or_error, response) + } + } + }; + } + + #[doc(hidden)] + pub trait Phase { + fn convert_error( + &self, + error: BoxError, + output_or_error: Option>, + response: Option, + ) -> SdkError; + } + + fn convert_construction_failure( + error: BoxError, + _: Option>, + _: Option, + ) -> SdkError { + SdkError::construction_failure(error) + } + + fn convert_dispatch_error( + error: BoxError, + _: Option>, + response: Option, + ) -> SdkError { + let error = match error.downcast::() { + Ok(connector_error) => { + return SdkError::dispatch_failure(*connector_error); + } + Err(e) => e, + }; + if let Some(response) = response { + SdkError::response_error(error, response) + } else { + SdkError::dispatch_failure(ConnectorError::other(error, None)) + } + } + + fn convert_response_handling_error( + error: BoxError, + output_or_error: Option>, + response: Option, + ) -> SdkError { + match (response, output_or_error) { + (Some(response), Some(Err(error))) => SdkError::service_error(error, response), + (Some(response), _) => SdkError::response_error(error, response), + _ => unreachable!("phase has a response"), + } + } + + /// Represents the phase of an operation prior to serialization. + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct BeforeSerialization; + impl_phase!(BeforeSerialization, convert_construction_failure); + + #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct Serialization; + impl_phase!(Serialization, convert_construction_failure); + + /// Represents the phase of an operation prior to transmitting a request over the network. + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct BeforeTransmit; + impl_phase!(BeforeTransmit, convert_dispatch_error); + + #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct Transmit; + impl_phase!(Transmit, convert_dispatch_error); + + /// Represents the phase of an operation after receiving a response, but before parsing that response. + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct BeforeDeserialization; + impl_phase!(BeforeDeserialization, convert_response_handling_error); + + #[doc(hidden)] // This one isn't exposed in the interceptors, but is used internally + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct Deserialization; + impl_phase!(Deserialization, convert_response_handling_error); + + /// Represents the phase of an operation after parsing a response. + #[derive(Default, Debug)] + #[non_exhaustive] + pub struct AfterDeserialization; + impl_phase!(AfterDeserialization, convert_response_handling_error); +} + /// A container for the data currently available to an interceptor. -pub struct InterceptorContext { - input: Option, - output_or_error: Option, +/// +/// Different context is available based on which phase the operation is currently in. For example, +/// context in the [`phase::BeforeSerialization`] phase won't have a `request` yet since the input hasn't been +/// serialized at that point. But once it gets into the [`phase::BeforeTransmit`] phase, the `request` will be set. +pub struct InterceptorContext { + input: Option, + output_or_error: Option>, request: Option, response: Option, + phase: Phase, } -// TODO(interceptors) we could use types to ensure that people calling methods on interceptor context can't access -// field that haven't been set yet. -impl InterceptorContext { - pub fn new(input: Input) -> Self { - Self { +// +// All phases +// +impl InterceptorContext<(), Input, Output, Error> { + /// Creates a new interceptor context in the [`phase::BeforeSerialization`] phase. + pub fn new( + input: Input, + ) -> InterceptorContext { + InterceptorContext { input: Some(input), output_or_error: None, request: None, response: None, + phase: Default::default(), } } +} +impl InterceptorContext { + /// Decomposes the context into its constituent parts. + #[doc(hidden)] + #[allow(clippy::type_complexity)] + pub fn into_parts( + self, + ) -> ( + Option, + Option>, + Option, + Option, + Phase, + ) { + ( + self.input, + self.output_or_error, + self.request, + self.response, + self.phase, + ) + } +} +// +// BeforeSerialization phase methods +// +impl InterceptorContext { /// Retrieve the input for the operation being invoked. - pub fn input(&self) -> Result<&Input, InterceptorError> { + pub fn input(&self) -> &I { self.input .as_ref() - .ok_or_else(InterceptorError::invalid_input_access) + .expect("input is present in phase::BeforeSerialization") } /// Retrieve the input for the operation being invoked. - pub fn input_mut(&mut self) -> Result<&mut Input, InterceptorError> { + pub fn input_mut(&mut self) -> &mut I { self.input .as_mut() - .ok_or_else(InterceptorError::invalid_input_access) + .expect("input is present in phase::BeforeSerialization") } - /// Takes ownership of the input. + /// Advance to the next phase. #[doc(hidden)] - pub fn take_input(&mut self) -> Option { + pub fn into_serialization_phase(self) -> InterceptorContext { + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: phase::Serialization::default(), + } + } +} + +// +// Serialization phase methods +// +impl InterceptorContext { + /// Takes ownership of the input. + pub fn take_input(&mut self) -> Option { self.input.take() } + pub fn set_request(&mut self, request: Request) { + debug_assert!( + self.request.is_none(), + "called set_request but a request was already set" + ); + self.request = Some(request); + } + + /// Advance to the next phase. + #[doc(hidden)] + pub fn into_before_transmit_phase(self) -> InterceptorContext { + debug_assert!( + self.input.is_none(), + "input must be taken before going into phase::BeforeTransmit" + ); + debug_assert!( + self.request.is_some(), + "request must be set before going into phase::BeforeTransmit" + ); + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: Default::default(), + } + } +} + +// +// BeforeTransmit phase methods +// +impl InterceptorContext { + /// Creates a new interceptor context in the [`phase::BeforeTransmit`] phase. + pub fn new( + input: Option, + request: HttpRequest, + ) -> InterceptorContext { + InterceptorContext { + input, + output_or_error: None, + request: Some(request), + response: None, + phase: Default::default(), + } + } + /// Retrieve the transmittable request for the operation being invoked. /// This will only be available once request marshalling has completed. - pub fn request(&self) -> Result<&Request, InterceptorError> { + pub fn request(&self) -> &Request { self.request .as_ref() - .ok_or_else(InterceptorError::invalid_request_access) + .expect("request populated in phase::BeforeTransmit") } /// Retrieve the transmittable request for the operation being invoked. /// This will only be available once request marshalling has completed. - pub fn request_mut(&mut self) -> Result<&mut Request, InterceptorError> { + pub fn request_mut(&mut self) -> &mut Request { self.request .as_mut() - .ok_or_else(InterceptorError::invalid_request_access) + .expect("request populated in phase::BeforeTransmit") } + #[doc(hidden)] + pub fn into_transmit_phase(self) -> InterceptorContext { + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: Default::default(), + } + } +} + +// +// Transmit phase methods +// +impl InterceptorContext { /// Takes ownership of the request. #[doc(hidden)] - pub fn take_request(&mut self) -> Option { - self.request.take() + pub fn take_request(&mut self) -> Request { + debug_assert!(self.request.is_some()); + self.request + .take() + .expect("take request once during transmit") + } + + #[doc(hidden)] + pub fn set_response(&mut self, response: Response) { + debug_assert!( + self.response.is_none(), + "called set_response but a response was already set" + ); + self.response = Some(response); } - /// Retrieve the response to the transmittable response for the operation - /// being invoked. This will only be available once transmission has - /// completed. - pub fn response(&self) -> Result<&Response, InterceptorError> { + #[doc(hidden)] + pub fn into_before_deserialization_phase( + self, + ) -> InterceptorContext { + debug_assert!( + self.request.is_none(), + "request must be taken before going into phase::BeforeDeserialization" + ); + debug_assert!( + self.response.is_some(), + "response must be set to before going into phase::BeforeDeserialization" + ); + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: Default::default(), + } + } +} + +impl InterceptorContext { + /// Returns the response. + pub fn response(&self) -> &Response { + self.response + .as_ref() + .expect("response set in phase::BeforeDeserialization") + } + + /// Returns a mutable reference to the response. + pub fn response_mut(&mut self) -> &mut Response { + self.response + .as_mut() + .expect("response set in phase::BeforeDeserialization") + } + + #[doc(hidden)] + pub fn into_deserialization_phase(self) -> InterceptorContext { + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: Default::default(), + } + } +} + +impl InterceptorContext { + /// Returns the response. + pub fn response(&self) -> &Response { self.response .as_ref() - .ok_or_else(InterceptorError::invalid_response_access) + .expect("response set in phase::Deserialization") } - /// Retrieve the response to the transmittable response for the operation - /// being invoked. This will only be available once transmission has - /// completed. - pub fn response_mut(&mut self) -> Result<&mut Response, InterceptorError> { + /// Returns a mutable reference to the response. + pub fn response_mut(&mut self) -> &mut Response { self.response .as_mut() - .ok_or_else(InterceptorError::invalid_response_access) + .expect("response set in phase::Deserialization") + } + + #[doc(hidden)] + pub fn set_output_or_error(&mut self, output: Result) { + debug_assert!(self.output_or_error.is_none()); + self.output_or_error = Some(output); } - /// Retrieve the response to the customer. This will only be available - /// once the `response` has been unmarshalled or the attempt/execution has failed. - pub fn output_or_error(&self) -> Result, InterceptorError> { + #[doc(hidden)] + pub fn into_after_deserialization_phase( + self, + ) -> InterceptorContext { + debug_assert!( + self.output_or_error.is_some(), + "output must be set to before going into phase::AfterDeserialization" + ); + InterceptorContext { + input: self.input, + output_or_error: self.output_or_error, + request: self.request, + response: self.response, + phase: Default::default(), + } + } +} + +impl InterceptorContext { + /// Returns the response. + pub fn response(&self) -> &Response { + self.response + .as_ref() + .expect("response set in phase::BeforeDeserialization") + } + + /// Returns a mutable reference to the response. + pub fn response_mut(&mut self) -> &mut Response { + self.response + .as_mut() + .expect("response set in phase::BeforeDeserialization") + } + + /// Returns the deserialized output or error. + pub fn output_or_error(&self) -> Result<&O, &E> { self.output_or_error .as_ref() - .ok_or_else(InterceptorError::invalid_output_access) - .map(|res| res.as_ref()) + .expect("output set in phase::AfterDeserialization") + .as_ref() } - /// Retrieve the response to the customer. This will only be available - /// once the `response` has been unmarshalled or the - /// attempt/execution has failed. - pub fn output_or_error_mut(&mut self) -> Result<&mut Result, InterceptorError> { + /// Returns the mutable reference to the deserialized output or error. + pub fn output_or_error_mut(&mut self) -> &mut Result { self.output_or_error .as_mut() - .ok_or_else(InterceptorError::invalid_output_access) + .expect("output set in phase::AfterDeserialization") } - // There is no set_input method because that can only be set once, during context construction + #[doc(hidden)] + pub fn finalize(self) -> Result> { + self.output_or_error + .expect("output already populated in the response handling phase") + .map_err(|error| { + SdkError::service_error( + error, + self.response + .expect("raw response already populated in the response handling phase"), + ) + }) + } +} - pub fn set_request(&mut self, request: Request) { - if self.request.is_some() { - panic!("Called set_request but a request was already set. This is a bug. Please report it."); +// This isn't great since it relies on a lot of runtime checking, but the +// compiler doesn't exactly make it easy to handle phase changes in a `loop`. +#[doc(hidden)] +pub struct AttemptCheckpoint { + tainted: bool, + checkpointed_request: Option, + before_transmit: Option>, + transmit: Option>, + before_deserialization: Option>, + deserialization: Option>, + after_deserialization: Option>, +} + +impl AttemptCheckpoint { + pub fn new(before_transmit: InterceptorContext) -> Self { + Self { + tainted: false, + checkpointed_request: Self::try_clone(before_transmit.request()), + before_transmit: Some(before_transmit), + transmit: None, + before_deserialization: None, + deserialization: None, + after_deserialization: None, } + } - self.request = Some(request); + pub fn before_transmit(&mut self) -> &mut InterceptorContext { + self.tainted = true; + self.before_transmit + .as_mut() + .expect("must be in the before transmit phase") } - pub fn set_response(&mut self, response: Response) { - if self.response.is_some() { - panic!("Called set_response but a transmit_response was already set. This is a bug. Please report it."); - } + pub fn transmit(&mut self) -> &mut InterceptorContext { + self.transmit + .as_mut() + .expect("must be in the transmit phase") + } - self.response = Some(response); + pub fn before_deser(&mut self) -> &mut InterceptorContext { + self.before_deserialization + .as_mut() + .expect("must be in the before deserialization phase") } - pub fn set_output_or_error(&mut self, output: Result) { - if self.output_or_error.is_some() { - panic!( - "Called set_output but an output was already set. This is a bug. Please report it." - ); + pub fn deser(&mut self) -> &mut InterceptorContext { + self.deserialization + .as_mut() + .expect("must be in the deserialization phase") + } + + pub fn after_deser(&mut self) -> &mut InterceptorContext { + self.after_deserialization + .as_mut() + .expect("must be in the after deserialization phase") + } + + pub fn transition_to_transmit(&mut self) { + self.transmit = Some( + self.before_transmit + .take() + .expect("must be in the before transmit phase") + .into_transmit_phase(), + ); + } + + pub fn transition_to_deserialization(&mut self) { + self.deserialization = Some( + self.before_deserialization + .take() + .expect("must be in the before deserialization phase") + .into_deserialization_phase(), + ) + } + + pub fn transition_to_before_deserialization(&mut self) { + self.before_deserialization = Some( + self.transmit + .take() + .expect("must be in the transmit phase") + .into_before_deserialization_phase(), + ) + } + + pub fn transition_to_after_deserialization(&mut self) { + self.after_deserialization = Some( + self.deserialization + .take() + .expect("must be in the deserialization phase") + .into_after_deserialization_phase(), + ) + } + + // Returns false if rewinding isn't possible + pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> bool { + // If before transmit was never touched, then we don't need to rewind + if !self.tainted { + return true; } + // If checkpointed_request was never set, then this is not a retryable request + if self.checkpointed_request.is_none() { + return false; + } + // Otherwise, rewind back to the beginning of BeforeTransmit + // TODO(enableNewSmithyRuntime): Also rewind the ConfigBag + fn into_input(context: InterceptorContext

) -> Option { + context.into_parts().0 + } + // Take the input from the current phase + let input = None + .or(self.before_transmit.take().map(into_input)) + .or(self.transmit.take().map(into_input)) + .or(self.before_deserialization.take().map(into_input)) + .or(self.deserialization.take().map(into_input)) + .or(self.after_deserialization.take().map(into_input)) + .expect("at least one phase must be in progress"); + let fresh_request = + Self::try_clone(self.checkpointed_request.as_ref().expect("checked above")) + .expect("cloneable request"); + self.before_transmit = Some(InterceptorContext::::new( + input, + fresh_request, + )); + true + } - self.output_or_error = Some(output); + pub fn into_error(self, reason: BoxError) -> SdkError { + fn err( + context: InterceptorContext

, + ) -> Box SdkError> { + Box::new(move |reason| { + let (_input, output_or_error, _request, response, phase) = context.into_parts(); + phase.convert_error(reason, output_or_error, response) + }) + } + // Convert the current phase into an error + (None + .or(self.before_transmit.map(err)) + .or(self.transmit.map(err)) + .or(self.before_deserialization.map(err)) + .or(self.deserialization.map(err)) + .or(self.after_deserialization.map(err)) + .expect("at least one phase must be in progress"))(reason) } - #[doc(hidden)] - pub fn into_parts( - self, - ) -> ( - Option, - Option, - Option, - Option, - ) { - ( - self.input, - self.output_or_error, - self.request, - self.response, + pub fn finalize(self) -> Result> { + self.after_deserialization + .expect("must be in the after deserialization phase") + .finalize() + } + + pub fn try_clone(request: &HttpRequest) -> Option { + let cloned_body = request.body().try_clone()?; + let mut cloned_request = ::http::Request::builder() + .uri(request.uri().clone()) + .method(request.method()); + *cloned_request + .headers_mut() + .expect("builder has not been modified, headers must be valid") = + request.headers().clone(); + Some( + cloned_request + .body(cloned_body) + .expect("a clone of a valid request should be a valid request"), ) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::type_erasure::TypedBox; + use aws_smithy_http::body::SdkBody; + use http::header::{AUTHORIZATION, CONTENT_LENGTH}; + use http::{HeaderValue, Uri}; + + #[test] + fn test_success_transitions() { + let input = TypedBox::new("input".to_string()).erase(); + let output = TypedBox::new("output".to_string()).erase(); + + let mut context = InterceptorContext::<()>::new(input); + assert_eq!("input", context.input().downcast_ref::().unwrap()); + context.input_mut(); + + let mut context = context.into_serialization_phase(); + let _ = context.take_input(); + context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + + let mut checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); + checkpoint.before_transmit().request(); + checkpoint.before_transmit().request_mut(); + + checkpoint.transition_to_transmit(); + let _ = checkpoint.transmit().take_request(); + checkpoint + .transmit() + .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + + checkpoint.transition_to_before_deserialization(); + checkpoint.before_deser().response(); + checkpoint.before_deser().response_mut(); + + checkpoint.transition_to_deserialization(); + checkpoint.deser().response(); + checkpoint.deser().response_mut(); + checkpoint.deser().set_output_or_error(Ok(output)); + + checkpoint.transition_to_after_deserialization(); + checkpoint.after_deser().response(); + checkpoint.after_deser().response_mut(); + let _ = checkpoint.after_deser().output_or_error(); + let _ = checkpoint.after_deser().output_or_error_mut(); + + let output = checkpoint.finalize().expect("success"); + assert_eq!("output", output.downcast_ref::().unwrap()); + } + + #[test] + fn test_rewind_for_retry() { + use std::fmt; + #[derive(Debug)] + struct Error; + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("dontcare") + } + } + impl std::error::Error for Error {} + + let mut cfg = ConfigBag::base(); + let input = TypedBox::new("input".to_string()).erase(); + let output = TypedBox::new("output".to_string()).erase(); + let error = TypedBox::new(Error).erase_error(); + + let context = InterceptorContext::<()>::new(input); + assert_eq!("input", context.input().downcast_ref::().unwrap()); + + let mut context = context.into_serialization_phase(); + let _ = context.take_input(); + context.set_request( + http::Request::builder() + .header("test", "the-original-unmutated-request") + .body(SdkBody::empty()) + .unwrap(), + ); + + let mut checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); + + // Modify the test header post-checkpoint to simulate modifying the request for signing or a mutating interceptor + checkpoint + .before_transmit() + .request_mut() + .headers_mut() + .remove("test"); + checkpoint + .before_transmit() + .request_mut() + .headers_mut() + .insert( + "test", + HeaderValue::from_static("request-modified-after-signing"), + ); + + checkpoint.transition_to_transmit(); + let request = checkpoint.transmit().take_request(); + assert_eq!( + "request-modified-after-signing", + request.headers().get("test").unwrap() + ); + checkpoint + .transmit() + .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + + checkpoint.transition_to_before_deserialization(); + checkpoint.transition_to_deserialization(); + checkpoint.deser().set_output_or_error(Err(error)); + + assert!(checkpoint.rewind(&mut cfg)); + + // Now after rewinding, the test header should be its original value + assert_eq!( + "the-original-unmutated-request", + checkpoint + .before_transmit() + .request() + .headers() + .get("test") + .unwrap() + ); + + checkpoint.transition_to_transmit(); + let _ = checkpoint.transmit().take_request(); + checkpoint + .transmit() + .set_response(http::Response::builder().body(SdkBody::empty()).unwrap()); + + checkpoint.transition_to_before_deserialization(); + checkpoint.transition_to_deserialization(); + checkpoint.deser().set_output_or_error(Ok(output)); + + checkpoint.transition_to_after_deserialization(); + + let output = checkpoint.finalize().expect("success"); + assert_eq!("output", output.downcast_ref::().unwrap()); + } + + #[test] + fn try_clone_clones_all_data() { + let request = ::http::Request::builder() + .uri(Uri::from_static("http://www.amazon.com")) + .method("POST") + .header(CONTENT_LENGTH, 456) + .header(AUTHORIZATION, "Token: hello") + .body(SdkBody::from("hello world!")) + .expect("valid request"); + let cloned = AttemptCheckpoint::try_clone(&request).expect("request is cloneable"); + + assert_eq!(&Uri::from_static("http://www.amazon.com"), cloned.uri()); + assert_eq!("POST", cloned.method()); + assert_eq!(2, cloned.headers().len()); + assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap(),); + assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap()); + assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap()); + } +} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index 30751d2120cd5ec5a6204d28136986a4f244e524..404d03b6975c77fe94fab5247a135c821a99b4e4 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -5,7 +5,7 @@ use crate::client::auth::{AuthOptionResolver, AuthOptionResolverParams, HttpAuthSchemes}; use crate::client::identity::IdentityResolvers; -use crate::client::interceptors::context::{Input, OutputOrError}; +use crate::client::interceptors::context::{Error, Input, Output}; use crate::client::retries::RetryClassifiers; use crate::client::retries::RetryStrategy; use crate::config_bag::ConfigBag; @@ -26,21 +26,17 @@ pub type BoxError = Box; pub type BoxFuture = Pin>>>; pub type Future = NowOrLater, BoxFuture>; -pub trait TraceProbe: Send + Sync + fmt::Debug { - fn dispatch_events(&self); -} - pub trait RequestSerializer: Send + Sync + fmt::Debug { fn serialize_input(&self, input: Input) -> Result; } pub trait ResponseDeserializer: Send + Sync + fmt::Debug { - fn deserialize_streaming(&self, response: &mut HttpResponse) -> Option { + fn deserialize_streaming(&self, response: &mut HttpResponse) -> Option> { let _ = response; None } - fn deserialize_nonstreaming(&self, response: &HttpResponse) -> OutputOrError; + fn deserialize_nonstreaming(&self, response: &HttpResponse) -> Result; } pub trait Connection: Send + Sync + fmt::Debug { @@ -138,9 +134,6 @@ pub trait ConfigBagAccessors { fn retry_strategy(&self) -> &dyn RetryStrategy; fn set_retry_strategy(&mut self, retry_strategy: impl RetryStrategy + 'static); - fn trace_probe(&self) -> &dyn TraceProbe; - fn set_trace_probe(&mut self, trace_probe: impl TraceProbe + 'static); - fn request_time(&self) -> Option; fn set_request_time(&mut self, request_time: RequestTime); @@ -263,16 +256,6 @@ impl ConfigBagAccessors for ConfigBag { self.put::>(Box::new(retry_strategy)); } - fn trace_probe(&self) -> &dyn TraceProbe { - &**self - .get::>() - .expect("missing trace probe") - } - - fn set_trace_probe(&mut self, trace_probe: impl TraceProbe + 'static) { - self.put::>(Box::new(trace_probe)); - } - fn request_time(&self) -> Option { self.get::().cloned() } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs index 1331317afe893ca28f4111d14b914be633826a7b..0d6678d359e7a9a9bf477876b0989a10e78103f8 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/retries.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use crate::client::interceptors::context::phase::AfterDeserialization; use crate::client::interceptors::context::Error; use crate::client::interceptors::InterceptorContext; use crate::client::orchestrator::BoxError; @@ -23,7 +24,7 @@ pub trait RetryStrategy: Send + Sync + Debug { fn should_attempt_retry( &self, - context: &InterceptorContext, + context: &InterceptorContext, cfg: &ConfigBag, ) -> Result; } diff --git a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs index 4c20c6c4d482c307f8812811542b79f2ace9609b..f2269c1c444e24c460f59d91d72ece67a5a8dca7 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs @@ -3,7 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -use std::any::{type_name, Any}; +use std::any::Any; +use std::error::Error as StdError; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; @@ -62,13 +63,24 @@ where } } +impl TypedBox +where + T: StdError + fmt::Debug + Send + Sync + 'static, +{ + /// Converts `TypedBox` to a `TypeErasedError` where `T` implements `Error`. + pub fn erase_error(self) -> TypeErasedError { + let inner = self.inner.downcast::().expect("typechecked"); + TypeErasedError::new(inner) + } +} + impl fmt::Debug for TypedBox where T: Send + Sync + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("TypedBox:")?; - (self.inner.debug)(&self.inner, f) + (self.inner.debug)(&self.inner.field, f) } } @@ -89,56 +101,125 @@ impl DerefMut for TypedBox { /// A new-type around `Box` pub struct TypeErasedBox { field: Box, - #[allow(dead_code)] - type_name: &'static str, #[allow(clippy::type_complexity)] - debug: Box) -> fmt::Result + Send + Sync>, + debug: Box< + dyn Fn(&Box, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + >, } impl fmt::Debug for TypeErasedBox { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("TypeErasedBox:")?; - (self.debug)(self, f) + (self.debug)(&self.field, f) } } impl TypeErasedBox { pub fn new(value: T) -> Self { - let debug = |value: &TypeErasedBox, f: &mut fmt::Formatter<'_>| { + let debug = |value: &Box, f: &mut fmt::Formatter<'_>| { fmt::Debug::fmt(value.downcast_ref::().expect("typechecked"), f) }; - let name = type_name::(); Self { field: Box::new(value), - type_name: name, debug: Box::new(debug), } } // Downcast into a `Box`, or return `Self` if it is not a `T`. pub fn downcast(self) -> Result, Self> { - let TypeErasedBox { + let TypeErasedBox { field, debug } = self; + field.downcast().map_err(|field| Self { field, debug }) + } + + /// Downcast as a `&T`, or return `None` if it is not a `T`. + pub fn downcast_ref(&self) -> Option<&T> { + self.field.downcast_ref() + } + + /// Downcast as a `&mut T`, or return `None` if it is not a `T`. + pub fn downcast_mut(&mut self) -> Option<&mut T> { + self.field.downcast_mut() + } +} + +impl From for TypeErasedBox { + fn from(value: TypeErasedError) -> Self { + TypeErasedBox { + field: value.field, + debug: value.debug, + } + } +} + +/// A new-type around `Box` that also implements `Error` +pub struct TypeErasedError { + field: Box, + #[allow(clippy::type_complexity)] + debug: Box< + dyn Fn(&Box, &mut fmt::Formatter<'_>) -> fmt::Result + Send + Sync, + >, + #[allow(clippy::type_complexity)] + as_error: Box Fn(&'a TypeErasedError) -> &'a (dyn StdError) + Send + Sync>, +} + +impl fmt::Debug for TypeErasedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("TypeErasedError:")?; + (self.debug)(&self.field, f) + } +} + +impl fmt::Display for TypeErasedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt((self.as_error)(self), f) + } +} + +impl StdError for TypeErasedError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + (self.as_error)(self).source() + } +} + +impl TypeErasedError { + pub fn new(value: T) -> Self { + let debug = |value: &Box, f: &mut fmt::Formatter<'_>| { + fmt::Debug::fmt(value.downcast_ref::().expect("typechecked"), f) + }; + Self { + field: Box::new(value), + debug: Box::new(debug), + as_error: Box::new(|value: &TypeErasedError| { + value.downcast_ref::().expect("typechecked") as _ + }), + } + } + + // Downcast into a `Box`, or return `Self` if it is not a `T`. + pub fn downcast( + self, + ) -> Result, Self> { + let TypeErasedError { field, - type_name, debug, + as_error, } = self; - match field.downcast() { - Ok(t) => Ok(t), - Err(s) => Err(Self { - field: s, - type_name, - debug, - }), - } + field.downcast().map_err(|field| Self { + field, + debug, + as_error, + }) } /// Downcast as a `&T`, or return `None` if it is not a `T`. - pub fn downcast_ref(&self) -> Option<&T> { + pub fn downcast_ref(&self) -> Option<&T> { self.field.downcast_ref() } /// Downcast as a `&mut T`, or return `None` if it is not a `T`. - pub fn downcast_mut(&mut self) -> Option<&mut T> { + pub fn downcast_mut( + &mut self, + ) -> Option<&mut T> { self.field.downcast_mut() } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index c0f4b4f06b86eb6586640fa77182ef092e7fc712..b174122dc87f6fc6e9e0e8fa9175da421f8eef13 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -6,10 +6,12 @@ use self::auth::orchestrate_auth; use crate::client::orchestrator::endpoints::orchestrate_endpoint; use crate::client::orchestrator::http::read_body; -use crate::client::orchestrator::phase::Phase; use crate::client::timeout::{MaybeTimeout, ProvideMaybeTimeoutConfig, TimeoutKind}; use aws_smithy_http::result::SdkError; -use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output}; +use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeSerialization; +use aws_smithy_runtime_api::client::interceptors::context::{ + AttemptCheckpoint, Error, Input, Output, +}; use aws_smithy_runtime_api::client::interceptors::{InterceptorContext, Interceptors}; use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpResponse}; use aws_smithy_runtime_api::client::retries::ShouldAttempt; @@ -21,18 +23,46 @@ mod auth; /// Defines types that implement a trait for endpoint resolution pub mod endpoints; mod http; -pub(self) mod phase; -pub async fn invoke( - input: Input, - runtime_plugins: &RuntimePlugins, -) -> Result> { - invoke_pre_config(input, runtime_plugins) - .instrument(debug_span!("invoke")) - .await +#[doc(hidden)] +#[macro_export] +macro_rules! handle_err { + ([$checkpoint:expr] => $expr:expr) => { + match $expr { + Ok(ok) => ok, + Err(err) => { + return Err($checkpoint.into_error(err.into())); + } + } + }; + ($ctx:expr => $expr:expr) => { + match $expr { + Ok(ok) => ok, + Err(err) => { + use aws_smithy_runtime_api::client::interceptors::context::phase::Phase; + let (_input, output_or_error, _request, response, phase) = $ctx.into_parts(); + return Err(phase.convert_error(err.into(), output_or_error, response)); + } + } + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! bail { + ([$checkpoint:expr], $reason:expr) => {{ + return Err($checkpoint.into_error($reason.into())); + }}; + ($ctx:expr, $reason:expr) => {{ + use aws_smithy_runtime_api::client::interceptors::context::phase::Phase; + let reason: BoxError = $reason.into(); + let (_input, output_or_error, _request, response, phase) = $ctx.into_parts(); + return Err(phase.convert_error(reason, output_or_error, response)); + }}; } -async fn invoke_pre_config( +#[tracing::instrument(skip_all)] +pub async fn invoke( input: Input, runtime_plugins: &RuntimePlugins, ) -> Result> { @@ -40,15 +70,14 @@ async fn invoke_pre_config( let cfg = &mut cfg; let mut interceptors = Interceptors::new(); + let context = InterceptorContext::<()>::new(input); - let context = Phase::construction(InterceptorContext::new(input)) - // Client configuration - .include(|_| runtime_plugins.apply_client_configuration(cfg, &mut interceptors))? - .include(|ctx| interceptors.client_read_before_execution(ctx, cfg))? - // Operation configuration - .include(|_| runtime_plugins.apply_operation_configuration(cfg, &mut interceptors))? - .include(|ctx| interceptors.operation_read_before_execution(ctx, cfg))? - .finish(); + // Client configuration + handle_err!(context => runtime_plugins.apply_client_configuration(cfg, &mut interceptors)); + handle_err!(context => interceptors.client_read_before_execution(&context, cfg)); + // Operation configuration + handle_err!(context => runtime_plugins.apply_operation_configuration(cfg, &mut interceptors)); + handle_err!(context => interceptors.operation_read_before_execution(&context, cfg)); let operation_timeout_config = cfg.maybe_timeout_config(TimeoutKind::Operation); invoke_post_config(cfg, context, interceptors) @@ -58,26 +87,26 @@ async fn invoke_pre_config( async fn invoke_post_config( cfg: &mut ConfigBag, - context: InterceptorContext, + mut before_serialization: InterceptorContext, interceptors: Interceptors, ) -> Result> { - let context = Phase::construction(context) - // Before serialization - .include(|ctx| interceptors.read_before_serialization(ctx, cfg))? - .include_mut(|ctx| interceptors.modify_before_serialization(ctx, cfg))? - // Serialization - .include_mut(|ctx| { - let request_serializer = cfg.request_serializer(); - let request = request_serializer - .serialize_input(ctx.take_input().expect("input set at this point"))?; - ctx.set_request(request); - Result::<(), BoxError>::Ok(()) - })? - // After serialization - .include(|ctx| interceptors.read_after_serialization(ctx, cfg))? - // Before retry loop - .include_mut(|ctx| interceptors.modify_before_retry_loop(ctx, cfg))? - .finish(); + // Before serialization + handle_err!(before_serialization => interceptors.read_before_serialization(&before_serialization, cfg)); + handle_err!(before_serialization => interceptors.modify_before_serialization(&mut before_serialization, cfg)); + + // Serialization + let mut serialization = before_serialization.into_serialization_phase(); + { + let request_serializer = cfg.request_serializer(); + let request = handle_err!(serialization => request_serializer + .serialize_input(serialization.take_input().expect("input set at this point"))); + serialization.set_request(request); + } + + // Before transmit + let mut before_transmit = serialization.into_before_transmit_phase(); + handle_err!(before_transmit => interceptors.read_after_serialization(&before_transmit, cfg)); + handle_err!(before_transmit => interceptors.modify_before_retry_loop(&mut before_transmit, cfg)); { let retry_strategy = cfg.retry_strategy(); @@ -86,97 +115,89 @@ async fn invoke_post_config( Ok(ShouldAttempt::Yes) => {} // No, this request shouldn't be sent Ok(ShouldAttempt::No) => { - return Err(Phase::dispatch(context).fail( - "The retry strategy indicates that an initial request shouldn't be made, but it didn't specify why.", - )) + bail!(before_transmit, "The retry strategy indicates that an initial request shouldn't be made, but it didn't specify why."); } // No, we shouldn't make a request because... - Err(err) => return Err(Phase::dispatch(context).fail(err)), + Err(err) => bail!(before_transmit, err), Ok(ShouldAttempt::YesAfterDelay(_)) => { unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.") } } } - let mut context = context; - let handling_phase = loop { + let mut checkpoint = AttemptCheckpoint::new(before_transmit); + checkpoint = loop { + if !checkpoint.rewind(cfg) { + break checkpoint; + } let attempt_timeout_config = cfg.maybe_timeout_config(TimeoutKind::OperationAttempt); - let dispatch_phase = Phase::dispatch(context); - context = make_an_attempt(dispatch_phase, cfg, &interceptors) - .instrument(debug_span!("make_an_attempt")) + + checkpoint = make_an_attempt(checkpoint, cfg, &interceptors) .maybe_timeout_with_config(attempt_timeout_config) - .await? - .include(|ctx| interceptors.read_after_attempt(ctx, cfg))? - .include_mut(|ctx| interceptors.modify_before_attempt_completion(ctx, cfg))? - .finish(); + .await?; + handle_err!([checkpoint] => interceptors.read_after_attempt(checkpoint.after_deser(), cfg)); + handle_err!([checkpoint] => interceptors.modify_before_attempt_completion(checkpoint.after_deser(), cfg)); let retry_strategy = cfg.retry_strategy(); - match retry_strategy.should_attempt_retry(&context, cfg) { + match retry_strategy.should_attempt_retry(checkpoint.after_deser(), cfg) { // Yes, let's retry the request Ok(ShouldAttempt::Yes) => continue, // No, this request shouldn't be retried Ok(ShouldAttempt::No) => {} Ok(ShouldAttempt::YesAfterDelay(_delay)) => { + // TODO(enableNewSmithyRuntime): implement retries with explicit delay todo!("implement retries with an explicit delay.") } // I couldn't determine if the request should be retried because an error occurred. - Err(err) => { - return Err(Phase::response_handling(context).fail(err)); - } + Err(err) => bail!([checkpoint], err), } - let handling_phase = Phase::response_handling(context) - .include_mut(|ctx| interceptors.modify_before_completion(ctx, cfg))?; - cfg.trace_probe().dispatch_events(); - - break handling_phase.include(|ctx| interceptors.read_after_execution(ctx, cfg))?; + break checkpoint; }; - handling_phase.finalize() + handle_err!([checkpoint] => interceptors.modify_before_completion(checkpoint.after_deser(), cfg)); + handle_err!([checkpoint] => interceptors.read_after_execution(checkpoint.after_deser(), cfg)); + + checkpoint.finalize() } // Making an HTTP request can fail for several reasons, but we still need to // call lifecycle events when that happens. Therefore, we define this // `make_an_attempt` function to make error handling simpler. +#[tracing::instrument(skip_all)] async fn make_an_attempt( - dispatch_phase: Phase, + mut checkpoint: AttemptCheckpoint, cfg: &mut ConfigBag, interceptors: &Interceptors, -) -> Result> { - let dispatch_phase = dispatch_phase - .include(|ctx| interceptors.read_before_attempt(ctx, cfg))? - .include_mut(|ctx| orchestrate_endpoint(ctx, cfg))? - .include_mut(|ctx| interceptors.modify_before_signing(ctx, cfg))? - .include(|ctx| interceptors.read_before_signing(ctx, cfg))?; +) -> Result> { + handle_err!([checkpoint] => interceptors.read_before_attempt(checkpoint.before_transmit(), cfg)); + handle_err!([checkpoint] => orchestrate_endpoint(checkpoint.before_transmit(), cfg)); + handle_err!([checkpoint] => interceptors.modify_before_signing(checkpoint.before_transmit(), cfg)); + handle_err!([checkpoint] => interceptors.read_before_signing(checkpoint.before_transmit(), cfg)); - let dispatch_phase = orchestrate_auth(dispatch_phase, cfg).await?; + checkpoint = orchestrate_auth(checkpoint, cfg).await?; - let mut context = dispatch_phase - .include(|ctx| interceptors.read_after_signing(ctx, cfg))? - .include_mut(|ctx| interceptors.modify_before_transmit(ctx, cfg))? - .include(|ctx| interceptors.read_before_transmit(ctx, cfg))? - .finish(); + handle_err!([checkpoint] => interceptors.read_after_signing(checkpoint.before_transmit(), cfg)); + handle_err!([checkpoint] => interceptors.modify_before_transmit(checkpoint.before_transmit(), cfg)); + handle_err!([checkpoint] => interceptors.read_before_transmit(checkpoint.before_transmit(), cfg)); // The connection consumes the request but we need to keep a copy of it // within the interceptor context, so we clone it here. - let call_result = { - let request = context.take_request().expect("request has been set"); - let connection = cfg.connection(); - connection.call(request).await - }; - - let mut context = Phase::dispatch(context) - .include_mut(move |ctx| { - ctx.set_response(call_result?); - Result::<(), BoxError>::Ok(()) - })? - .include(|ctx| interceptors.read_after_transmit(ctx, cfg))? - .include_mut(|ctx| interceptors.modify_before_deserialization(ctx, cfg))? - .include(|ctx| interceptors.read_before_deserialization(ctx, cfg))? - .finish(); - - let output_or_error = { - let response = context.response_mut().expect("response has been set"); + checkpoint.transition_to_transmit(); + let call_result = handle_err!([checkpoint] => { + let request = checkpoint.transmit().take_request(); + cfg.connection().call(request).await + }); + checkpoint.transmit().set_response(call_result); + checkpoint.transition_to_before_deserialization(); + + handle_err!([checkpoint] => interceptors.read_after_transmit(checkpoint.before_deser(), cfg)); + handle_err!([checkpoint] => interceptors.modify_before_deserialization(checkpoint.before_deser(), cfg)); + handle_err!([checkpoint] => interceptors.read_before_deserialization(checkpoint.before_deser(), cfg)); + + checkpoint.transition_to_deserialization(); + let output_or_error = handle_err!([checkpoint] => { + let response = checkpoint.deser().response_mut(); let response_deserializer = cfg.response_deserializer(); match response_deserializer.deserialize_streaming(response) { Some(output_or_error) => Ok(output_or_error), @@ -185,12 +206,12 @@ async fn make_an_attempt( .await .map(|_| response_deserializer.deserialize_nonstreaming(response)), } - }; + }); + + checkpoint.deser().set_output_or_error(output_or_error); + + checkpoint.transition_to_after_deserialization(); + handle_err!([checkpoint] => interceptors.read_after_deserialization(checkpoint.after_deser(), cfg)); - Phase::response_handling(context) - .include_mut(move |ctx| { - ctx.set_output_or_error(output_or_error?); - Result::<(), BoxError>::Ok(()) - })? - .include(|ctx| interceptors.read_after_deserialization(ctx, cfg)) + Ok(checkpoint) } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs index 9818593f059bc6e8e20f37515a00daf0f5d8273c..374795d3d4220942af62c438cdf775f1842cbffc 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/auth.rs @@ -3,16 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -use super::phase::Phase; +use crate::client::orchestrator::AttemptCheckpoint; +use crate::{bail, handle_err}; use aws_smithy_http::result::SdkError; use aws_smithy_runtime_api::client::interceptors::context::Error; use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpResponse}; use aws_smithy_runtime_api::config_bag::ConfigBag; pub(super) async fn orchestrate_auth( - dispatch_phase: Phase, + mut checkpoint: AttemptCheckpoint, cfg: &ConfigBag, -) -> Result> { +) -> Result> { fn construction_failure(err: impl Into) -> SdkError { SdkError::construction_failure(err) } @@ -39,18 +40,17 @@ pub(super) async fn orchestrate_auth( .resolve_identity(cfg) .await .map_err(construction_failure)?; - return dispatch_phase.include_mut(|ctx| { - let request = ctx.request_mut()?; - request_signer.sign_request(request, &identity, cfg)?; - Result::<_, BoxError>::Ok(()) - }); + let request = checkpoint.before_transmit().request_mut(); + handle_err!([checkpoint] => request_signer.sign_request(request, &identity, cfg)); + return Ok(checkpoint); } } } - Err(construction_failure( - "no auth scheme matched auth options. This is a bug. Please file an issue.", - )) + bail!( + [checkpoint], + "no auth scheme matched auth options. This is a bug. Please file an issue." + ); } #[cfg(test)] @@ -117,8 +117,10 @@ mod tests { } let input = TypedBox::new("doesnt-matter").erase(); - let mut context = InterceptorContext::new(input); + let mut context = InterceptorContext::<()>::new(input).into_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = context.take_input(); + let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); let mut cfg = ConfigBag::base(); cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter")); @@ -134,17 +136,13 @@ mod tests { .build(), ); - let phase = Phase::dispatch(context); - let context = orchestrate_auth(phase, &cfg) - .await - .expect("success") - .finish(); + let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); assert_eq!( "success!", - context + checkpoint + .before_transmit() .request() - .unwrap() .headers() .get("Authorization") .unwrap() @@ -160,8 +158,11 @@ mod tests { }; use aws_smithy_runtime_api::client::identity::http::{Login, Token}; - let mut context = InterceptorContext::new(TypedBox::new("doesnt-matter").erase()); + let mut context = InterceptorContext::<()>::new(TypedBox::new("doesnt-matter").erase()) + .into_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = context.take_input(); + let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); let mut cfg = ConfigBag::base(); cfg.set_auth_option_resolver_params(AuthOptionResolverParams::new("doesntmatter")); @@ -183,18 +184,14 @@ mod tests { .build(), ); - let phase = Phase::dispatch(context); - let context = orchestrate_auth(phase, &cfg) - .await - .expect("success") - .finish(); + let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); assert_eq!( // "YTpi" == "a:b" in base64 "Basic YTpi", - context + checkpoint + .before_transmit() .request() - .unwrap() .headers() .get("Authorization") .unwrap() @@ -207,19 +204,18 @@ mod tests { .build(), ); - let mut context = InterceptorContext::new(TypedBox::new("doesnt-matter").erase()); + let mut context = InterceptorContext::<()>::new(TypedBox::new("doesnt-matter").erase()) + .into_serialization_phase(); context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap()); + let _ = context.take_input(); + let checkpoint = AttemptCheckpoint::new(context.into_before_transmit_phase()); - let context = orchestrate_auth(Phase::dispatch(context), &cfg) - .await - .expect("success") - .finish(); - + let mut checkpoint = orchestrate_auth(checkpoint, &cfg).await.expect("success"); assert_eq!( "Bearer t", - context + checkpoint + .before_transmit() .request() - .unwrap() .headers() .get("Authorization") .unwrap() diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs index 89570c1ef65ddfa552521d1a2fcf0b58ad3c4b34..08551f5bee6129208fb50cd57976c93ce011f26a 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs @@ -7,6 +7,7 @@ use aws_smithy_http::endpoint::error::ResolveEndpointError; use aws_smithy_http::endpoint::{ apply_endpoint, EndpointPrefix, ResolveEndpoint, SharedEndpointResolver, }; +use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::{ BoxError, ConfigBagAccessors, EndpointResolver, EndpointResolverParams, HttpRequest, @@ -112,12 +113,12 @@ where } pub(super) fn orchestrate_endpoint( - ctx: &mut InterceptorContext, + ctx: &mut InterceptorContext, cfg: &ConfigBag, ) -> Result<(), BoxError> { let params = cfg.endpoint_resolver_params(); let endpoint_prefix = cfg.get::(); - let request = ctx.request_mut()?; + let request = ctx.request_mut(); let endpoint_resolver = cfg.endpoint_resolver(); endpoint_resolver.resolve_and_apply_endpoint(params, endpoint_prefix, request)?; diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/phase.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/phase.rs deleted file mode 100644 index 9e101050b7ecbd08b42891aa6ddc409841dc1f4b..0000000000000000000000000000000000000000 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/phase.rs +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_http::result::{ConnectorError, SdkError}; -use aws_smithy_runtime_api::client::interceptors::context::{Error, Output}; -use aws_smithy_runtime_api::client::interceptors::InterceptorContext; -use aws_smithy_runtime_api::client::orchestrator::{BoxError, HttpResponse}; - -#[derive(Copy, Clone, Eq, PartialEq)] -enum OrchestrationPhase { - Construction, - Dispatch, - ResponseHandling, -} - -pub(super) struct Phase { - phase: OrchestrationPhase, - context: InterceptorContext, -} - -impl Phase { - pub(crate) fn construction(context: InterceptorContext) -> Self { - Self::start(OrchestrationPhase::Construction, context) - } - pub(crate) fn dispatch(context: InterceptorContext) -> Self { - Self::start(OrchestrationPhase::Dispatch, context) - } - pub(crate) fn response_handling(context: InterceptorContext) -> Self { - Self::start(OrchestrationPhase::ResponseHandling, context) - } - - fn start(phase: OrchestrationPhase, context: InterceptorContext) -> Self { - match phase { - OrchestrationPhase::Construction => {} - OrchestrationPhase::Dispatch => {} - OrchestrationPhase::ResponseHandling => debug_assert!(context.response().is_ok()), - } - Self { phase, context } - } - - pub(crate) fn include_mut>( - mut self, - c: impl FnOnce(&mut InterceptorContext) -> Result<(), E>, - ) -> Result> { - match c(&mut self.context) { - Ok(_) => Ok(self), - Err(e) => Err(self.fail(e)), - } - } - - pub(crate) fn include>( - self, - c: impl FnOnce(&InterceptorContext) -> Result<(), E>, - ) -> Result> { - match c(&self.context) { - Ok(_) => Ok(self), - Err(e) => Err(self.fail(e)), - } - } - - pub(crate) fn fail(self, e: impl Into) -> SdkError { - self.into_sdk_error(e.into()) - } - - pub(crate) fn finalize(self) -> Result> { - debug_assert!(self.phase == OrchestrationPhase::ResponseHandling); - let (_input, output_or_error, _request, response) = self.context.into_parts(); - match output_or_error { - Some(output_or_error) => match output_or_error { - Ok(output) => Ok(output), - Err(error) => Err(SdkError::service_error( - error, - response.expect("response must be set by this point"), - )), - }, - None => unreachable!("phase can't get this far without bubbling up a failure"), - } - } - - fn into_sdk_error(self, e: BoxError) -> SdkError { - let e = match e.downcast::() { - Ok(connector_error) => { - debug_assert!( - self.phase == OrchestrationPhase::Dispatch, - "connector errors should only occur during the dispatch phase" - ); - return SdkError::dispatch_failure(*connector_error); - } - Err(e) => e, - }; - let (_input, output_or_error, _request, response) = self.context.into_parts(); - match self.phase { - OrchestrationPhase::Construction => SdkError::construction_failure(e), - OrchestrationPhase::Dispatch => { - if let Some(response) = response { - SdkError::response_error(e, response) - } else { - SdkError::dispatch_failure(ConnectorError::other(e, None)) - } - } - OrchestrationPhase::ResponseHandling => match (response, output_or_error) { - (Some(response), Some(Err(error))) => SdkError::service_error(error, response), - (Some(response), _) => SdkError::response_error(e, response), - _ => unreachable!("response handling phase at least has a response"), - }, - } - } - - pub(crate) fn finish(self) -> InterceptorContext { - self.context - } -} diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs index f415af36912da423ac0ed39dc08400e33a0f1595..3b1b56ffb759ec7f56f41b15b26bf991a68f755b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/never.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_runtime_api::client::interceptors::context::phase::AfterDeserialization; use aws_smithy_runtime_api::client::interceptors::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::BoxError; use aws_smithy_runtime_api::client::retries::{RetryStrategy, ShouldAttempt}; @@ -24,7 +25,7 @@ impl RetryStrategy for NeverRetryStrategy { fn should_attempt_retry( &self, - _context: &InterceptorContext, + _context: &InterceptorContext, _cfg: &ConfigBag, ) -> Result { Ok(ShouldAttempt::No)