From c6e537969eb8ce473b0ec9f3fd970c7bcd436f21 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Fri, 26 May 2023 09:46:39 -0700 Subject: [PATCH] Fix event streams in the orchestrator implementation (#2673) ## Motivation and Context This PR gets event streams working in the client orchestrator implementation, and depends on #2671. The orchestrator's `TypeErasedBox` enforces a `Send + Sync` requirement on inputs and outputs. For the most part, this isn't an issue since almost all generated inputs/outputs are `Send + Sync`, but it turns out the `EventStreamSender` wasn't `Sync` due to an omission of the `Sync` bound. Thus, this PR is a breaking change, as it adds a `Sync` requirement for anyone who passes a stream to an event stream operation. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --- CHANGELOG.next.toml | 12 + aws/rust-runtime/aws-runtime/Cargo.toml | 5 + aws/rust-runtime/aws-runtime/src/auth.rs | 470 +------------- .../aws-runtime/src/auth/sigv4.rs | 609 ++++++++++++++++++ aws/rust-runtime/aws-sigv4/src/lib.rs | 12 + .../smithy/rustsdk/SigV4AuthDecorator.kt | 8 +- .../protocol/MakeOperationGenerator.kt | 8 +- .../protocol/RequestSerializerGenerator.kt | 13 +- .../protocols/HttpBoundProtocolGenerator.kt | 16 +- .../protocol/ProtocolTestGeneratorTest.kt | 10 +- .../generators/protocol/ProtocolGenerator.kt | 15 +- .../HttpBoundProtocolPayloadGenerator.kt | 39 +- .../src/event_stream/sender.rs | 19 +- .../src/client/orchestrator.rs | 12 +- .../src/client/orchestrator.rs | 2 +- .../src/client/test_util/serializer.rs | 6 +- 16 files changed, 757 insertions(+), 499 deletions(-) create mode 100644 aws/rust-runtime/aws-runtime/src/auth/sigv4.rs diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 86fe2d7fe..f81531ff7 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -50,3 +50,15 @@ See the changes in https://github.com/awslabs/smithy-rs/pull/2671 for an example references = ["smithy-rs#2671"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } author = "jdisanti" + +[[aws-sdk-rust]] +message = "For event stream operations such as S3 SelectObjectContent or Transcribe StartStreamTranscription, the `EventStreamSender` in the input now requires the passed in `Stream` impl to implement `Sync`." +references = ["smithy-rs#2673"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" + +[[smithy-rs]] +message = "For event stream operations, the `EventStreamSender` in inputs/outputs now requires the passed in `Stream` impl to implement `Sync`." +references = ["smithy-rs#2673"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all"} +author = "jdisanti" diff --git a/aws/rust-runtime/aws-runtime/Cargo.toml b/aws/rust-runtime/aws-runtime/Cargo.toml index cd0e4bc0c..c5bb21ca1 100644 --- a/aws/rust-runtime/aws-runtime/Cargo.toml +++ b/aws/rust-runtime/aws-runtime/Cargo.toml @@ -7,10 +7,14 @@ edition = "2021" license = "Apache-2.0" repository = "https://github.com/awslabs/smithy-rs" +[features] +event-stream = ["dep:aws-smithy-eventstream", "aws-sigv4/sign-eventstream"] + [dependencies] aws-credential-types = { path = "../aws-credential-types" } aws-http = { path = "../aws-http" } aws-sigv4 = { path = "../aws-sigv4" } +aws-smithy-eventstream = { path = "../../../rust-runtime/aws-smithy-eventstream", optional = true } aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" } aws-smithy-runtime = { path = "../../../rust-runtime/aws-smithy-runtime" } aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api" } @@ -22,6 +26,7 @@ tracing = "0.1" uuid = { version = "1", features = ["v4", "fast-rng"] } [dev-dependencies] +aws-credential-types = { path = "../aws-credential-types", features = ["test-util"] } aws-smithy-protocol-test = { path = "../../../rust-runtime/aws-smithy-protocol-test" } proptest = "1" serde = { version = "1", features = ["derive"]} diff --git a/aws/rust-runtime/aws-runtime/src/auth.rs b/aws/rust-runtime/aws-runtime/src/auth.rs index f436a5f02..149466a58 100644 --- a/aws/rust-runtime/aws-runtime/src/auth.rs +++ b/aws/rust-runtime/aws-runtime/src/auth.rs @@ -4,472 +4,4 @@ */ /// Auth implementations for SigV4. -pub mod sigv4 { - use aws_credential_types::Credentials; - use aws_sigv4::http_request::{ - sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableBody, - SignableRequest, SignatureLocation, SigningParams, SigningSettings, - UriPathNormalizationMode, - }; - use aws_smithy_runtime_api::client::auth::{ - AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner, - }; - use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers}; - use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest}; - use aws_smithy_runtime_api::config_bag::ConfigBag; - use aws_smithy_types::Document; - use aws_types::region::{Region, SigningRegion}; - use aws_types::SigningService; - use std::borrow::Cow; - use std::error::Error as StdError; - use std::fmt; - use std::time::{Duration, SystemTime}; - - const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \ - `expires_in` duration because the credentials used to sign it will expire first."; - - /// Auth scheme ID for SigV4. - pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4"); - - struct EndpointAuthSchemeConfig { - signing_region_override: Option, - signing_service_override: Option, - } - - #[derive(Debug)] - enum SigV4SigningError { - MissingOperationSigningConfig, - MissingSigningRegion, - MissingSigningService, - WrongIdentityType(Identity), - BadTypeInEndpointAuthSchemeConfig(&'static str), - } - - impl fmt::Display for SigV4SigningError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use SigV4SigningError::*; - let mut w = |s| f.write_str(s); - match self { - MissingOperationSigningConfig => w("missing operation signing config for SigV4"), - MissingSigningRegion => w("missing signing region for SigV4 signing"), - MissingSigningService => w("missing signing service for SigV4 signing"), - WrongIdentityType(identity) => { - write!(f, "wrong identity type for SigV4: {identity:?}") - } - BadTypeInEndpointAuthSchemeConfig(field_name) => { - write!( - f, - "unexpected type for `{field_name}` in endpoint auth scheme config", - ) - } - } - } - } - - impl StdError for SigV4SigningError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match self { - Self::MissingOperationSigningConfig => None, - Self::MissingSigningRegion => None, - Self::MissingSigningService => None, - Self::WrongIdentityType(_) => None, - Self::BadTypeInEndpointAuthSchemeConfig(_) => None, - } - } - } - - /// SigV4 auth scheme. - #[derive(Debug, Default)] - pub struct SigV4HttpAuthScheme { - signer: SigV4HttpRequestSigner, - } - - impl SigV4HttpAuthScheme { - /// Creates a new `SigV4HttpAuthScheme`. - pub fn new() -> Self { - Default::default() - } - } - - impl HttpAuthScheme for SigV4HttpAuthScheme { - fn scheme_id(&self) -> AuthSchemeId { - SCHEME_ID - } - - fn identity_resolver<'a>( - &self, - identity_resolvers: &'a IdentityResolvers, - ) -> Option<&'a dyn IdentityResolver> { - identity_resolvers.identity_resolver(self.scheme_id()) - } - - fn request_signer(&self) -> &dyn HttpRequestSigner { - &self.signer - } - } - - /// Type of SigV4 signature. - #[derive(Debug, Eq, PartialEq, Clone, Copy)] - pub enum HttpSignatureType { - /// A signature for a full http request should be computed, with header updates applied to the signing result. - HttpRequestHeaders, - - /// A signature for a full http request should be computed, with query param updates applied to the signing result. - /// - /// This is typically used for presigned URLs. - HttpRequestQueryParams, - } - - /// Signing options for SigV4. - #[derive(Clone, Debug, Eq, PartialEq)] - #[non_exhaustive] - pub struct SigningOptions { - /// Apply URI encoding twice. - pub double_uri_encode: bool, - /// Apply a SHA-256 payload checksum. - pub content_sha256_header: bool, - /// Normalize the URI path before signing. - pub normalize_uri_path: bool, - /// Omit the session token from the signature. - pub omit_session_token: bool, - /// Optional override for the payload to be used in signing. - pub payload_override: Option>, - /// Signature type. - pub signature_type: HttpSignatureType, - /// Whether or not the signature is optional. - pub signing_optional: bool, - /// Optional expiration (for presigning) - pub expires_in: Option, - } - - impl Default for SigningOptions { - fn default() -> Self { - Self { - double_uri_encode: true, - content_sha256_header: false, - normalize_uri_path: true, - omit_session_token: false, - payload_override: None, - signature_type: HttpSignatureType::HttpRequestHeaders, - signing_optional: false, - expires_in: None, - } - } - } - - /// SigV4 signing configuration for an operation - /// - /// Although these fields MAY be customized on a per request basis, they are generally static - /// for a given operation - #[derive(Clone, Debug, PartialEq, Eq)] - pub struct SigV4OperationSigningConfig { - /// AWS Region to sign for. - pub region: Option, - /// AWS Service to sign for. - pub service: Option, - /// Signing options. - pub signing_options: SigningOptions, - } - - /// SigV4 HTTP request signer. - #[derive(Debug, Default)] - pub struct SigV4HttpRequestSigner; - - impl SigV4HttpRequestSigner { - /// Creates a new signer instance. - pub fn new() -> Self { - Self - } - - fn settings(operation_config: &SigV4OperationSigningConfig) -> SigningSettings { - let mut settings = SigningSettings::default(); - settings.percent_encoding_mode = if operation_config.signing_options.double_uri_encode { - PercentEncodingMode::Double - } else { - PercentEncodingMode::Single - }; - settings.payload_checksum_kind = - if operation_config.signing_options.content_sha256_header { - PayloadChecksumKind::XAmzSha256 - } else { - PayloadChecksumKind::NoHeader - }; - settings.uri_path_normalization_mode = - if operation_config.signing_options.normalize_uri_path { - UriPathNormalizationMode::Enabled - } else { - UriPathNormalizationMode::Disabled - }; - settings.session_token_mode = if operation_config.signing_options.omit_session_token { - SessionTokenMode::Exclude - } else { - SessionTokenMode::Include - }; - settings.signature_location = match operation_config.signing_options.signature_type { - HttpSignatureType::HttpRequestHeaders => SignatureLocation::Headers, - HttpSignatureType::HttpRequestQueryParams => SignatureLocation::QueryParams, - }; - settings.expires_in = operation_config.signing_options.expires_in; - settings - } - - fn signing_params<'a>( - settings: SigningSettings, - credentials: &'a Credentials, - operation_config: &'a SigV4OperationSigningConfig, - request_timestamp: SystemTime, - ) -> Result, SigV4SigningError> { - if let Some(expires_in) = settings.expires_in { - if let Some(creds_expires_time) = credentials.expiry() { - let presigned_expires_time = request_timestamp + expires_in; - if presigned_expires_time > creds_expires_time { - tracing::warn!(EXPIRATION_WARNING); - } - } - } - - let mut builder = SigningParams::builder() - .access_key(credentials.access_key_id()) - .secret_key(credentials.secret_access_key()) - .region( - operation_config - .region - .as_ref() - .ok_or(SigV4SigningError::MissingSigningRegion)? - .as_ref(), - ) - .service_name( - operation_config - .service - .as_ref() - .ok_or(SigV4SigningError::MissingSigningService)? - .as_ref(), - ) - .time(request_timestamp) - .settings(settings); - builder.set_security_token(credentials.session_token()); - Ok(builder.build().expect("all required fields set")) - } - - fn extract_operation_config<'a>( - auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'a>, - config_bag: &'a ConfigBag, - ) -> Result, SigV4SigningError> { - let operation_config = config_bag - .get::() - .ok_or(SigV4SigningError::MissingOperationSigningConfig)?; - - let EndpointAuthSchemeConfig { - signing_region_override, - signing_service_override, - } = Self::extract_endpoint_auth_scheme_config(auth_scheme_endpoint_config)?; - - match (signing_region_override, signing_service_override) { - (None, None) => Ok(Cow::Borrowed(operation_config)), - (region, service) => { - let mut operation_config = operation_config.clone(); - if region.is_some() { - operation_config.region = region; - } - if service.is_some() { - operation_config.service = service; - } - Ok(Cow::Owned(operation_config)) - } - } - } - - fn extract_endpoint_auth_scheme_config( - endpoint_config: AuthSchemeEndpointConfig<'_>, - ) -> Result { - let (mut signing_region_override, mut signing_service_override) = (None, None); - if let Some(config) = endpoint_config.config().and_then(Document::as_object) { - use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType; - signing_region_override = match config.get("signingRegion") { - Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))), - None => None, - _ => return Err(UnexpectedType("signingRegion")), - }; - signing_service_override = match config.get("signingName") { - Some(Document::String(s)) => Some(SigningService::from(s.to_string())), - None => None, - _ => return Err(UnexpectedType("signingName")), - }; - } - Ok(EndpointAuthSchemeConfig { - signing_region_override, - signing_service_override, - }) - } - } - - impl HttpRequestSigner for SigV4HttpRequestSigner { - fn sign_request( - &self, - request: &mut HttpRequest, - identity: &Identity, - auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>, - config_bag: &ConfigBag, - ) -> Result<(), BoxError> { - let operation_config = - Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?; - let request_time = config_bag.request_time().unwrap_or_default().system_time(); - - let credentials = if let Some(creds) = identity.data::() { - creds - } else if operation_config.signing_options.signing_optional { - tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials"); - return Ok(()); - } else { - return Err(SigV4SigningError::WrongIdentityType(identity.clone()).into()); - }; - - let settings = Self::settings(&operation_config); - let signing_params = - Self::signing_params(settings, credentials, &operation_config, request_time)?; - - let (signing_instructions, _signature) = { - // A body that is already in memory can be signed directly. A body that is not in memory - // (any sort of streaming body or presigned request) will be signed via UNSIGNED-PAYLOAD. - let signable_body = operation_config - .signing_options - .payload_override - .as_ref() - // the payload_override is a cheap clone because it contains either a - // reference or a short checksum (we're not cloning the entire body) - .cloned() - .unwrap_or_else(|| { - request - .body() - .bytes() - .map(SignableBody::Bytes) - .unwrap_or(SignableBody::UnsignedPayload) - }); - - let signable_request = SignableRequest::new( - request.method(), - request.uri(), - request.headers(), - signable_body, - ); - sign(signable_request, &signing_params)? - } - .into_parts(); - - signing_instructions.apply_to_request(request); - Ok(()) - } - } - - #[cfg(test)] - mod tests { - use super::*; - use aws_credential_types::Credentials; - use aws_sigv4::http_request::SigningSettings; - use aws_types::region::SigningRegion; - use aws_types::SigningService; - use std::collections::HashMap; - use std::time::{Duration, SystemTime}; - use tracing_test::traced_test; - - #[test] - #[traced_test] - fn expiration_warning() { - let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1000); - let creds_expire_in = Duration::from_secs(100); - - let mut settings = SigningSettings::default(); - settings.expires_in = Some(creds_expire_in - Duration::from_secs(10)); - - let credentials = Credentials::new( - "test-access-key", - "test-secret-key", - Some("test-session-token".into()), - Some(now + creds_expire_in), - "test", - ); - let operation_config = SigV4OperationSigningConfig { - region: Some(SigningRegion::from_static("test")), - service: Some(SigningService::from_static("test")), - signing_options: SigningOptions { - double_uri_encode: true, - content_sha256_header: true, - normalize_uri_path: true, - omit_session_token: true, - signature_type: HttpSignatureType::HttpRequestHeaders, - signing_optional: false, - expires_in: None, - payload_override: None, - }, - }; - SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now) - .unwrap(); - assert!(!logs_contain(EXPIRATION_WARNING)); - - let mut settings = SigningSettings::default(); - settings.expires_in = Some(creds_expire_in + Duration::from_secs(10)); - - SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now) - .unwrap(); - assert!(logs_contain(EXPIRATION_WARNING)); - } - - #[test] - fn endpoint_config_overrides_region_and_service() { - let mut cfg = ConfigBag::base(); - cfg.put(SigV4OperationSigningConfig { - region: Some(SigningRegion::from(Region::new("override-this-region"))), - service: Some(SigningService::from_static("override-this-service")), - signing_options: Default::default(), - }); - let config = Document::Object({ - let mut out = HashMap::new(); - out.insert("name".to_string(), "sigv4".to_string().into()); - out.insert( - "signingName".to_string(), - "qldb-override".to_string().into(), - ); - out.insert( - "signingRegion".to_string(), - "us-east-override".to_string().into(), - ); - out - }); - let config = AuthSchemeEndpointConfig::new(Some(&config)); - - let result = - SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success"); - - assert_eq!( - result.region, - Some(SigningRegion::from(Region::new("us-east-override"))) - ); - assert_eq!( - result.service, - Some(SigningService::from_static("qldb-override")) - ); - assert!(matches!(result, Cow::Owned(_))); - } - - #[test] - fn endpoint_config_supports_fallback_when_region_or_service_are_unset() { - let mut cfg = ConfigBag::base(); - cfg.put(SigV4OperationSigningConfig { - region: Some(SigningRegion::from(Region::new("us-east-1"))), - service: Some(SigningService::from_static("qldb")), - signing_options: Default::default(), - }); - let config = AuthSchemeEndpointConfig::empty(); - - let result = - SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success"); - - assert_eq!( - result.region, - Some(SigningRegion::from(Region::new("us-east-1"))) - ); - assert_eq!(result.service, Some(SigningService::from_static("qldb"))); - assert!(matches!(result, Cow::Borrowed(_))); - } - } -} +pub mod sigv4; diff --git a/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs b/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs new file mode 100644 index 000000000..8d7255684 --- /dev/null +++ b/aws/rust-runtime/aws-runtime/src/auth/sigv4.rs @@ -0,0 +1,609 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_credential_types::Credentials; +use aws_sigv4::http_request::{ + sign, PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableBody, + SignableRequest, SignatureLocation, SigningParams, SigningSettings, UriPathNormalizationMode, +}; +use aws_smithy_runtime_api::client::auth::{ + AuthSchemeEndpointConfig, AuthSchemeId, HttpAuthScheme, HttpRequestSigner, +}; +use aws_smithy_runtime_api::client::identity::{Identity, IdentityResolver, IdentityResolvers}; +use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors, HttpRequest}; +use aws_smithy_runtime_api::config_bag::ConfigBag; +use aws_smithy_types::Document; +use aws_types::region::{Region, SigningRegion}; +use aws_types::SigningService; +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt; +use std::time::{Duration, SystemTime}; + +const EXPIRATION_WARNING: &str = "Presigned request will expire before the given \ + `expires_in` duration because the credentials used to sign it will expire first."; + +/// Auth scheme ID for SigV4. +pub const SCHEME_ID: AuthSchemeId = AuthSchemeId::new("sigv4"); + +struct EndpointAuthSchemeConfig { + signing_region_override: Option, + signing_service_override: Option, +} + +#[derive(Debug)] +enum SigV4SigningError { + MissingOperationSigningConfig, + MissingSigningRegion, + MissingSigningService, + WrongIdentityType(Identity), + BadTypeInEndpointAuthSchemeConfig(&'static str), +} + +impl fmt::Display for SigV4SigningError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use SigV4SigningError::*; + let mut w = |s| f.write_str(s); + match self { + MissingOperationSigningConfig => w("missing operation signing config for SigV4"), + MissingSigningRegion => w("missing signing region for SigV4 signing"), + MissingSigningService => w("missing signing service for SigV4 signing"), + WrongIdentityType(identity) => { + write!(f, "wrong identity type for SigV4: {identity:?}") + } + BadTypeInEndpointAuthSchemeConfig(field_name) => { + write!( + f, + "unexpected type for `{field_name}` in endpoint auth scheme config", + ) + } + } + } +} + +impl StdError for SigV4SigningError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + match self { + Self::MissingOperationSigningConfig => None, + Self::MissingSigningRegion => None, + Self::MissingSigningService => None, + Self::WrongIdentityType(_) => None, + Self::BadTypeInEndpointAuthSchemeConfig(_) => None, + } + } +} + +/// SigV4 auth scheme. +#[derive(Debug, Default)] +pub struct SigV4HttpAuthScheme { + signer: SigV4HttpRequestSigner, +} + +impl SigV4HttpAuthScheme { + /// Creates a new `SigV4HttpAuthScheme`. + pub fn new() -> Self { + Default::default() + } +} + +impl HttpAuthScheme for SigV4HttpAuthScheme { + fn scheme_id(&self) -> AuthSchemeId { + SCHEME_ID + } + + fn identity_resolver<'a>( + &self, + identity_resolvers: &'a IdentityResolvers, + ) -> Option<&'a dyn IdentityResolver> { + identity_resolvers.identity_resolver(self.scheme_id()) + } + + fn request_signer(&self) -> &dyn HttpRequestSigner { + &self.signer + } +} + +/// Type of SigV4 signature. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum HttpSignatureType { + /// A signature for a full http request should be computed, with header updates applied to the signing result. + HttpRequestHeaders, + + /// A signature for a full http request should be computed, with query param updates applied to the signing result. + /// + /// This is typically used for presigned URLs. + HttpRequestQueryParams, +} + +/// Signing options for SigV4. +#[derive(Clone, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub struct SigningOptions { + /// Apply URI encoding twice. + pub double_uri_encode: bool, + /// Apply a SHA-256 payload checksum. + pub content_sha256_header: bool, + /// Normalize the URI path before signing. + pub normalize_uri_path: bool, + /// Omit the session token from the signature. + pub omit_session_token: bool, + /// Optional override for the payload to be used in signing. + pub payload_override: Option>, + /// Signature type. + pub signature_type: HttpSignatureType, + /// Whether or not the signature is optional. + pub signing_optional: bool, + /// Optional expiration (for presigning) + pub expires_in: Option, +} + +impl Default for SigningOptions { + fn default() -> Self { + Self { + double_uri_encode: true, + content_sha256_header: false, + normalize_uri_path: true, + omit_session_token: false, + payload_override: None, + signature_type: HttpSignatureType::HttpRequestHeaders, + signing_optional: false, + expires_in: None, + } + } +} + +/// SigV4 signing configuration for an operation +/// +/// Although these fields MAY be customized on a per request basis, they are generally static +/// for a given operation +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SigV4OperationSigningConfig { + /// AWS Region to sign for. + pub region: Option, + /// AWS Service to sign for. + pub service: Option, + /// Signing options. + pub signing_options: SigningOptions, +} + +/// SigV4 HTTP request signer. +#[derive(Debug, Default)] +pub struct SigV4HttpRequestSigner; + +impl SigV4HttpRequestSigner { + /// Creates a new signer instance. + pub fn new() -> Self { + Self + } + + fn settings(operation_config: &SigV4OperationSigningConfig) -> SigningSettings { + let mut settings = SigningSettings::default(); + settings.percent_encoding_mode = if operation_config.signing_options.double_uri_encode { + PercentEncodingMode::Double + } else { + PercentEncodingMode::Single + }; + settings.payload_checksum_kind = if operation_config.signing_options.content_sha256_header { + PayloadChecksumKind::XAmzSha256 + } else { + PayloadChecksumKind::NoHeader + }; + settings.uri_path_normalization_mode = + if operation_config.signing_options.normalize_uri_path { + UriPathNormalizationMode::Enabled + } else { + UriPathNormalizationMode::Disabled + }; + settings.session_token_mode = if operation_config.signing_options.omit_session_token { + SessionTokenMode::Exclude + } else { + SessionTokenMode::Include + }; + settings.signature_location = match operation_config.signing_options.signature_type { + HttpSignatureType::HttpRequestHeaders => SignatureLocation::Headers, + HttpSignatureType::HttpRequestQueryParams => SignatureLocation::QueryParams, + }; + settings.expires_in = operation_config.signing_options.expires_in; + settings + } + + fn signing_params<'a>( + settings: SigningSettings, + credentials: &'a Credentials, + operation_config: &'a SigV4OperationSigningConfig, + request_timestamp: SystemTime, + ) -> Result, SigV4SigningError> { + if let Some(expires_in) = settings.expires_in { + if let Some(creds_expires_time) = credentials.expiry() { + let presigned_expires_time = request_timestamp + expires_in; + if presigned_expires_time > creds_expires_time { + tracing::warn!(EXPIRATION_WARNING); + } + } + } + + let mut builder = SigningParams::builder() + .access_key(credentials.access_key_id()) + .secret_key(credentials.secret_access_key()) + .region( + operation_config + .region + .as_ref() + .ok_or(SigV4SigningError::MissingSigningRegion)? + .as_ref(), + ) + .service_name( + operation_config + .service + .as_ref() + .ok_or(SigV4SigningError::MissingSigningService)? + .as_ref(), + ) + .time(request_timestamp) + .settings(settings); + builder.set_security_token(credentials.session_token()); + Ok(builder.build().expect("all required fields set")) + } + + fn extract_operation_config<'a>( + auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'a>, + config_bag: &'a ConfigBag, + ) -> Result, SigV4SigningError> { + let operation_config = config_bag + .get::() + .ok_or(SigV4SigningError::MissingOperationSigningConfig)?; + + let EndpointAuthSchemeConfig { + signing_region_override, + signing_service_override, + } = Self::extract_endpoint_auth_scheme_config(auth_scheme_endpoint_config)?; + + match (signing_region_override, signing_service_override) { + (None, None) => Ok(Cow::Borrowed(operation_config)), + (region, service) => { + let mut operation_config = operation_config.clone(); + if region.is_some() { + operation_config.region = region; + } + if service.is_some() { + operation_config.service = service; + } + Ok(Cow::Owned(operation_config)) + } + } + } + + fn extract_endpoint_auth_scheme_config( + endpoint_config: AuthSchemeEndpointConfig<'_>, + ) -> Result { + let (mut signing_region_override, mut signing_service_override) = (None, None); + if let Some(config) = endpoint_config.config().and_then(Document::as_object) { + use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType; + signing_region_override = match config.get("signingRegion") { + Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))), + None => None, + _ => return Err(UnexpectedType("signingRegion")), + }; + signing_service_override = match config.get("signingName") { + Some(Document::String(s)) => Some(SigningService::from(s.to_string())), + None => None, + _ => return Err(UnexpectedType("signingName")), + }; + } + Ok(EndpointAuthSchemeConfig { + signing_region_override, + signing_service_override, + }) + } +} + +impl HttpRequestSigner for SigV4HttpRequestSigner { + fn sign_request( + &self, + request: &mut HttpRequest, + identity: &Identity, + auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>, + config_bag: &ConfigBag, + ) -> Result<(), BoxError> { + let operation_config = + Self::extract_operation_config(auth_scheme_endpoint_config, config_bag)?; + let request_time = config_bag.request_time().unwrap_or_default().system_time(); + + let credentials = if let Some(creds) = identity.data::() { + creds + } else if operation_config.signing_options.signing_optional { + tracing::debug!("skipped SigV4 signing since signing is optional for this operation and there are no credentials"); + return Ok(()); + } else { + return Err(SigV4SigningError::WrongIdentityType(identity.clone()).into()); + }; + + let settings = Self::settings(&operation_config); + let signing_params = + Self::signing_params(settings, credentials, &operation_config, request_time)?; + + let (signing_instructions, _signature) = { + // A body that is already in memory can be signed directly. A body that is not in memory + // (any sort of streaming body or presigned request) will be signed via UNSIGNED-PAYLOAD. + let signable_body = operation_config + .signing_options + .payload_override + .as_ref() + // the payload_override is a cheap clone because it contains either a + // reference or a short checksum (we're not cloning the entire body) + .cloned() + .unwrap_or_else(|| { + request + .body() + .bytes() + .map(SignableBody::Bytes) + .unwrap_or(SignableBody::UnsignedPayload) + }); + + let signable_request = SignableRequest::new( + request.method(), + request.uri(), + request.headers(), + signable_body, + ); + sign(signable_request, &signing_params)? + } + .into_parts(); + + // If this is an event stream operation, set up the event stream signer + #[cfg(feature = "event-stream")] + { + use aws_smithy_eventstream::frame::DeferredSignerSender; + use aws_smithy_runtime_api::client::orchestrator::RequestTime; + use event_stream::SigV4MessageSigner; + + if let Some(signer_sender) = config_bag.get::() { + let time_override = config_bag.get::().copied(); + signer_sender + .send(Box::new(SigV4MessageSigner::new( + _signature, + credentials.clone(), + Region::new(signing_params.region().to_string()).into(), + signing_params.service_name().to_string().into(), + time_override, + )) as _) + .expect("failed to send deferred signer"); + } + } + + signing_instructions.apply_to_request(request); + Ok(()) + } +} + +#[cfg(feature = "event-stream")] +mod event_stream { + use aws_credential_types::Credentials; + use aws_sigv4::event_stream::{sign_empty_message, sign_message}; + use aws_sigv4::SigningParams; + use aws_smithy_eventstream::frame::{Message, SignMessage, SignMessageError}; + use aws_smithy_runtime_api::client::orchestrator::RequestTime; + use aws_types::region::SigningRegion; + use aws_types::SigningService; + + /// Event Stream SigV4 signing implementation. + #[derive(Debug)] + pub(super) struct SigV4MessageSigner { + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, + } + + impl SigV4MessageSigner { + pub(super) fn new( + last_signature: String, + credentials: Credentials, + signing_region: SigningRegion, + signing_service: SigningService, + time: Option, + ) -> Self { + Self { + last_signature, + credentials, + signing_region, + signing_service, + time, + } + } + + fn signing_params(&self) -> SigningParams<'_, ()> { + let mut builder = SigningParams::builder() + .access_key(self.credentials.access_key_id()) + .secret_key(self.credentials.secret_access_key()) + .region(self.signing_region.as_ref()) + .service_name(self.signing_service.as_ref()) + .time(self.time.unwrap_or_default().system_time()) + .settings(()); + builder.set_security_token(self.credentials.session_token()); + builder.build().unwrap() + } + } + + impl SignMessage for SigV4MessageSigner { + fn sign(&mut self, message: Message) -> Result { + let (signed_message, signature) = { + let params = self.signing_params(); + sign_message(&message, &self.last_signature, ¶ms).into_parts() + }; + self.last_signature = signature; + Ok(signed_message) + } + + fn sign_empty(&mut self) -> Option> { + let (signed_message, signature) = { + let params = self.signing_params(); + sign_empty_message(&self.last_signature, ¶ms).into_parts() + }; + self.last_signature = signature; + Some(Ok(signed_message)) + } + } + + #[cfg(test)] + mod tests { + use super::*; + use aws_credential_types::Credentials; + use aws_smithy_eventstream::frame::{HeaderValue, Message, SignMessage}; + use aws_types::region::Region; + use aws_types::region::SigningRegion; + use aws_types::SigningService; + use std::time::{Duration, UNIX_EPOCH}; + + fn check_send_sync(value: T) -> T { + value + } + + #[test] + fn sign_message() { + let region = Region::new("us-east-1"); + let mut signer = check_send_sync(SigV4MessageSigner::new( + "initial-signature".into(), + Credentials::for_tests(), + SigningRegion::from(region), + SigningService::from_static("transcribe"), + Some(RequestTime::new(UNIX_EPOCH + Duration::new(1611160427, 0))), + )); + let mut signatures = Vec::new(); + for _ in 0..5 { + let signed = signer + .sign(Message::new(&b"identical message"[..])) + .unwrap(); + if let HeaderValue::ByteArray(signature) = signed + .headers() + .iter() + .find(|h| h.name().as_str() == ":chunk-signature") + .unwrap() + .value() + { + signatures.push(signature.clone()); + } else { + panic!("failed to get the :chunk-signature") + } + } + for i in 1..signatures.len() { + assert_ne!(signatures[i - 1], signatures[i]); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_credential_types::Credentials; + use aws_sigv4::http_request::SigningSettings; + use aws_types::region::SigningRegion; + use aws_types::SigningService; + use std::collections::HashMap; + use std::time::{Duration, SystemTime}; + use tracing_test::traced_test; + + #[test] + #[traced_test] + fn expiration_warning() { + let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1000); + let creds_expire_in = Duration::from_secs(100); + + let mut settings = SigningSettings::default(); + settings.expires_in = Some(creds_expire_in - Duration::from_secs(10)); + + let credentials = Credentials::new( + "test-access-key", + "test-secret-key", + Some("test-session-token".into()), + Some(now + creds_expire_in), + "test", + ); + let operation_config = SigV4OperationSigningConfig { + region: Some(SigningRegion::from_static("test")), + service: Some(SigningService::from_static("test")), + signing_options: SigningOptions { + double_uri_encode: true, + content_sha256_header: true, + normalize_uri_path: true, + omit_session_token: true, + signature_type: HttpSignatureType::HttpRequestHeaders, + signing_optional: false, + expires_in: None, + payload_override: None, + }, + }; + SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now) + .unwrap(); + assert!(!logs_contain(EXPIRATION_WARNING)); + + let mut settings = SigningSettings::default(); + settings.expires_in = Some(creds_expire_in + Duration::from_secs(10)); + + SigV4HttpRequestSigner::signing_params(settings, &credentials, &operation_config, now) + .unwrap(); + assert!(logs_contain(EXPIRATION_WARNING)); + } + + #[test] + fn endpoint_config_overrides_region_and_service() { + let mut cfg = ConfigBag::base(); + cfg.put(SigV4OperationSigningConfig { + region: Some(SigningRegion::from(Region::new("override-this-region"))), + service: Some(SigningService::from_static("override-this-service")), + signing_options: Default::default(), + }); + let config = Document::Object({ + let mut out = HashMap::new(); + out.insert("name".to_string(), "sigv4".to_string().into()); + out.insert( + "signingName".to_string(), + "qldb-override".to_string().into(), + ); + out.insert( + "signingRegion".to_string(), + "us-east-override".to_string().into(), + ); + out + }); + let config = AuthSchemeEndpointConfig::new(Some(&config)); + + let result = + SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success"); + + assert_eq!( + result.region, + Some(SigningRegion::from(Region::new("us-east-override"))) + ); + assert_eq!( + result.service, + Some(SigningService::from_static("qldb-override")) + ); + assert!(matches!(result, Cow::Owned(_))); + } + + #[test] + fn endpoint_config_supports_fallback_when_region_or_service_are_unset() { + let mut cfg = ConfigBag::base(); + cfg.put(SigV4OperationSigningConfig { + region: Some(SigningRegion::from(Region::new("us-east-1"))), + service: Some(SigningService::from_static("qldb")), + signing_options: Default::default(), + }); + let config = AuthSchemeEndpointConfig::empty(); + + let result = + SigV4HttpRequestSigner::extract_operation_config(config, &cfg).expect("success"); + + assert_eq!( + result.region, + Some(SigningRegion::from(Region::new("us-east-1"))) + ); + assert_eq!(result.service, Some(SigningService::from_static("qldb"))); + assert!(matches!(result, Cow::Borrowed(_))); + } +} diff --git a/aws/rust-runtime/aws-sigv4/src/lib.rs b/aws/rust-runtime/aws-sigv4/src/lib.rs index 1d3b36ade..d20df2399 100644 --- a/aws/rust-runtime/aws-sigv4/src/lib.rs +++ b/aws/rust-runtime/aws-sigv4/src/lib.rs @@ -49,6 +49,18 @@ pub struct SigningParams<'a, S> { pub(crate) settings: S, } +impl<'a, S> SigningParams<'a, S> { + /// Returns the region that will be used to sign + pub fn region(&self) -> &str { + self.region + } + + /// Returns the service name that will be used to sign + pub fn service_name(&self) -> &str { + self.service_name + } +} + impl<'a, S: fmt::Debug> fmt::Debug for SigningParams<'a, S> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SigningParams") diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt index 3ca470fd0..dbf6710a4 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecorator.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isInputEventStream import software.amazon.smithy.rust.codegen.core.util.letIf @@ -47,7 +48,7 @@ class SigV4AuthDecorator : ClientCodegenDecorator { } } -private class AuthServiceRuntimePluginCustomization(codegenContext: ClientCodegenContext) : +private class AuthServiceRuntimePluginCustomization(private val codegenContext: ClientCodegenContext) : ServiceRuntimePluginCustomization() { private val runtimeConfig = codegenContext.runtimeConfig private val codegenScope by lazy { @@ -72,6 +73,11 @@ private class AuthServiceRuntimePluginCustomization(codegenContext: ClientCodege } is ServiceRuntimePluginSection.AdditionalConfig -> { + val serviceHasEventStream = codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model) + if (serviceHasEventStream) { + // enable the aws-runtime `sign-eventstream` feature + addDependency(AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType().toSymbol()) + } section.putConfigValue(this) { rustTemplate("#{SigningService}::from_static(self.handle.conf.signing_service())", *codegenScope) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt index 8f2dcda29..ec806de95 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientAdditionalPayloadContext import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.docs @@ -108,7 +109,12 @@ open class MakeOperationGenerator( // Clippy warning to make the codegen a little simpler in that case. Attribute.AllowClippyUselessConversion.render(this) withBlockTemplate("let body = #{SdkBody}::from(", ");", *codegenScope) { - bodyGenerator.generatePayload(this, "self", shape) + bodyGenerator.generatePayload( + this, + "self", + shape, + ClientAdditionalPayloadContext(propertyBagAvailable = true), + ) val streamingMember = shape.inputShape(model).findStreamingMember(model) val isBlobStreaming = streamingMember != null && model.expectShape(streamingMember.target) is BlobShape if (isBlobStreaming) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt index 22870147b..90482e8e9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientAdditionalPayloadContext import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -37,6 +38,7 @@ class RequestSerializerGenerator( val orchestrator = runtimeApi.resolve("client::orchestrator") arrayOf( "BoxError" to orchestrator.resolve("BoxError"), + "ConfigBag" to runtimeApi.resolve("config_bag::ConfigBag"), "HttpRequest" to orchestrator.resolve("HttpRequest"), "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, "Input" to interceptorContext.resolve("Input"), @@ -61,7 +63,7 @@ class RequestSerializerGenerator( struct ${operationName}RequestSerializer; impl #{RequestSerializer} for ${operationName}RequestSerializer { ##[allow(unused_mut, clippy::let_and_return, clippy::needless_borrow, clippy::useless_conversion)] - fn serialize_input(&self, input: #{Input}) -> Result<#{HttpRequest}, #{BoxError}> { + fn serialize_input(&self, input: #{Input}, _cfg: &mut #{ConfigBag}) -> Result<#{HttpRequest}, #{BoxError}> { let input = #{TypedBox}::<#{ConcreteInput}>::assume_from(input).expect("correct type").unwrap(); let mut request_builder = { #{create_http_request} @@ -76,7 +78,14 @@ class RequestSerializerGenerator( "ConcreteInput" to inputSymbol, "create_http_request" to createHttpRequest(operationShape), "generate_body" to writable { - val body = writable { bodyGenerator.generatePayload(this, "input", operationShape) } + val body = writable { + bodyGenerator.generatePayload( + this, + "input", + operationShape, + ClientAdditionalPayloadContext(propertyBagAvailable = false), + ) + } val streamingMember = inputShape.findStreamingMember(codegenContext.model) val isBlobStreaming = streamingMember != null && codegenContext.model.expectShape(streamingMember.target) is BlobShape diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 69e0bdd48..6c30515ca 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.AdditionalPayloadContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -51,19 +52,25 @@ class HttpBoundProtocolGenerator( HttpBoundProtocolTraitImplGenerator(codegenContext, protocol), ) +// TODO(enableNewSmithyRuntime): Completely delete `AdditionalPayloadContext` when switching to the orchestrator +data class ClientAdditionalPayloadContext( + val propertyBagAvailable: Boolean, +) : AdditionalPayloadContext + class ClientHttpBoundProtocolPayloadGenerator( codegenContext: ClientCodegenContext, protocol: Protocol, ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( codegenContext, protocol, HttpMessageType.REQUEST, renderEventStreamBody = { writer, params -> + val propertyBagAvailable = (params.additionalPayloadContext as ClientAdditionalPayloadContext).propertyBagAvailable writer.rustTemplate( """ { let error_marshaller = #{errorMarshallerConstructorFn}(); let marshaller = #{marshallerConstructorFn}(); let (signer, signer_sender) = #{DeferredSigner}::new(); - properties.acquire_mut().insert(signer_sender); + #{insert_into_config} let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); @@ -76,6 +83,13 @@ class ClientHttpBoundProtocolPayloadGenerator( "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), "marshallerConstructorFn" to params.marshallerConstructorFn, "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, + "insert_into_config" to writable { + if (propertyBagAvailable) { + rust("properties.acquire_mut().insert(signer_sender);") + } else { + rust("_cfg.put(signer_sender);") + } + }, ) }, ) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index b03f37440..36f608916 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.AdditionalPayloadContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -35,10 +36,15 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape import java.nio.file.Path private class TestProtocolPayloadGenerator(private val body: String) : ProtocolPayloadGenerator { - override fun payloadMetadata(operationShape: OperationShape) = + override fun payloadMetadata(operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext) = ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = false) - override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { + override fun generatePayload( + writer: RustWriter, + shapeName: String, + operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext, + ) { writer.writeWithNoFormatting(body) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt index ceb385c6d..f48a99d85 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/protocol/ProtocolGenerator.kt @@ -10,6 +10,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +/** Allows for additional context to be given to the payload generator from where it is being called */ +interface AdditionalPayloadContext + /** * Payload Body Generator. * @@ -31,7 +34,10 @@ interface ProtocolPayloadGenerator { * Most operations will use the HTTP payload as a reference, but for operations that will consume the entire stream * later,they will need to take ownership and different code needs to be generated. */ - fun payloadMetadata(operationShape: OperationShape): PayloadMetadata + fun payloadMetadata( + operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext = object : AdditionalPayloadContext {}, + ): PayloadMetadata /** * Write the payload into [writer]. @@ -42,7 +48,12 @@ interface ProtocolPayloadGenerator { * - a `Vec` for non-streaming operations; or * - a `ByteStream` for streaming operations. */ - fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) + fun generatePayload( + writer: RustWriter, + shapeName: String, + operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext = object : AdditionalPayloadContext {}, + ) } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 1e17beca4..2c6ffc662 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.AdditionalPayloadContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamErrorMarshallerGenerator @@ -46,6 +47,7 @@ data class EventStreamBodyParams( val memberName: String, val marshallerConstructorFn: RuntimeType, val errorMarshallerConstructorFn: RuntimeType, + val additionalPayloadContext: AdditionalPayloadContext, ) class HttpBoundProtocolPayloadGenerator( @@ -69,7 +71,10 @@ class HttpBoundProtocolPayloadGenerator( ) private val protocolFunctions = ProtocolFunctions(codegenContext) - override fun payloadMetadata(operationShape: OperationShape): ProtocolPayloadGenerator.PayloadMetadata { + override fun payloadMetadata( + operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext, + ): ProtocolPayloadGenerator.PayloadMetadata { val (shape, payloadMemberName) = when (httpMessageType) { HttpMessageType.RESPONSE -> operationShape.outputShape(model) to httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName @@ -98,32 +103,43 @@ class HttpBoundProtocolPayloadGenerator( } } - override fun generatePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { + override fun generatePayload( + writer: RustWriter, + shapeName: String, + operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext, + ) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape) - HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape) + HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape, additionalPayloadContext) + HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape, additionalPayloadContext) } } - private fun generateRequestPayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { + private fun generateRequestPayload( + writer: RustWriter, shapeName: String, operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext, + ) { val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) } } - private fun generateResponsePayload(writer: RustWriter, shapeName: String, operationShape: OperationShape) { + private fun generateResponsePayload( + writer: RustWriter, shapeName: String, operationShape: OperationShape, + additionalPayloadContext: AdditionalPayloadContext, + ) { val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName) + generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) } } @@ -132,16 +148,17 @@ class HttpBoundProtocolPayloadGenerator( shapeName: String, operationShape: OperationShape, payloadMemberName: String, + additionalPayloadContext: AdditionalPayloadContext, ) { val serializerGenerator = protocol.structuredDataSerializer() if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName) + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName, additionalPayloadContext) } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output") + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output", additionalPayloadContext) } else { throw CodegenException("Payload serializer for event streams with an invalid configuration") } @@ -171,6 +188,7 @@ class HttpBoundProtocolPayloadGenerator( memberShape: MemberShape, serializerGenerator: StructuredDataSerializerGenerator, outerName: String, + additionalPayloadContext: AdditionalPayloadContext, ) { val memberName = symbolProvider.toMemberName(memberShape) val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) @@ -207,6 +225,7 @@ class HttpBoundProtocolPayloadGenerator( memberName, marshallerConstructorFn, errorMarshallerConstructorFn, + additionalPayloadContext, ), ) } diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index a4faa3f23..d19690e72 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -17,12 +17,14 @@ use tracing::trace; /// Input type for Event Streams. pub struct EventStreamSender { - input_stream: Pin> + Send>>, + input_stream: Pin> + Send + Sync>>, } impl Debug for EventStreamSender { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "EventStreamSender(Box)") + let name_t = std::any::type_name::(); + let name_e = std::any::type_name::(); + write!(f, "EventStreamSender<{name_t}, {name_e}>") } } @@ -40,7 +42,7 @@ impl EventStreamSender { impl From for EventStreamSender where - S: Stream> + Send + 'static, + S: Stream> + Send + Sync + 'static, { fn from(stream: S) -> Self { EventStreamSender { @@ -260,6 +262,17 @@ mod tests { } } + fn check_send_sync(value: T) -> T { + value + } + + #[test] + fn event_stream_sender_send_sync() { + check_send_sync(EventStreamSender::from(stream! { + yield Result::<_, SignMessageError>::Ok(TestMessage("test".into())); + })); + } + fn check_compatible_with_hyper_wrap_stream(stream: S) -> S where S: Stream> + Send + 'static, diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index 3b9aaff2f..3999c365b 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -33,7 +33,7 @@ pub type BoxFuture = Pin> + Se pub type Future = NowOrLater, BoxFuture>; pub trait RequestSerializer: Send + Sync + fmt::Debug { - fn serialize_input(&self, input: Input) -> Result; + fn serialize_input(&self, input: Input, cfg: &mut ConfigBag) -> Result; } pub trait ResponseDeserializer: Send + Sync + fmt::Debug { @@ -146,7 +146,7 @@ pub trait ConfigBagAccessors { fn http_auth_schemes(&self) -> &HttpAuthSchemes; fn set_http_auth_schemes(&mut self, http_auth_schemes: HttpAuthSchemes); - fn request_serializer(&self) -> &dyn RequestSerializer; + fn request_serializer(&self) -> Arc; fn set_request_serializer(&mut self, request_serializer: impl RequestSerializer + 'static); fn response_deserializer(&self) -> &dyn ResponseDeserializer; @@ -246,14 +246,14 @@ impl ConfigBagAccessors for ConfigBag { self.put::(http_auth_schemes); } - fn request_serializer(&self) -> &dyn RequestSerializer { - &**self - .get::>() + fn request_serializer(&self) -> Arc { + self.get::>() .expect("missing request serializer") + .clone() } fn set_request_serializer(&mut self, request_serializer: impl RequestSerializer + 'static) { - self.put::>(Box::new(request_serializer)); + self.put::>(Arc::new(request_serializer)); } fn response_deserializer(&self) -> &dyn ResponseDeserializer { diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index 51cd3ee71..75bc21b50 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -107,7 +107,7 @@ async fn try_op(ctx: &mut InterceptorContext, cfg: &mut ConfigBag, interceptors: { let request_serializer = cfg.request_serializer(); let input = ctx.take_input().expect("input set at this point"); - let request = halt_on_err!([ctx] => request_serializer.serialize_input(input)); + let request = halt_on_err!([ctx] => request_serializer.serialize_input(input, cfg)); ctx.set_request(request); } diff --git a/rust-runtime/aws-smithy-runtime/src/client/test_util/serializer.rs b/rust-runtime/aws-smithy-runtime/src/client/test_util/serializer.rs index d9a265978..9895fd2ca 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/test_util/serializer.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/test_util/serializer.rs @@ -39,7 +39,11 @@ impl CannedRequestSerializer { } impl RequestSerializer for CannedRequestSerializer { - fn serialize_input(&self, _input: Input) -> Result { + fn serialize_input( + &self, + _input: Input, + _cfg: &mut ConfigBag, + ) -> Result { let req = self .take() .ok_or("CannedRequestSerializer's inner value has already been taken.")?; -- GitLab