From 9d3bad7e6e3ae8a46f35a8df495c36d101eaa18d Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 7 Oct 2021 08:37:51 -0700 Subject: [PATCH] Fix bugs in Event Stream unmarshalling (#736) * Fix bugs in Event Stream unmarshalling - S3 Select's `Cont` event doesn't come across with a `:content-type` header, but the unmarshaller was expecting this header. - The unmarshaller attempted to parse events that were modeled as empty structs, but S3 sends an empty payload for empty structs rather than an empty XML payload conforming to restXml. - The event stream `Receiver` was returning that there were no more events even when it had some remaining events in its buffer. * Update changelogs * Incorporate feedback * Split `smithy_http::event_stream` into multiple modules * Add another test case to Receiver * Fix CI Co-authored-by: Russell Cohen --- CHANGELOG.md | 2 + aws/SDK_CHANGELOG.md | 2 + .../s3/tests/select-object-content.rs | 1 + .../rust/codegen/rustlang/CargoDependency.kt | 1 + .../rust/codegen/rustlang/RustWriter.kt | 2 +- .../parse/EventStreamUnmarshallerGenerator.kt | 79 ++- rust-runtime/smithy-eventstream/src/smithy.rs | 53 +- rust-runtime/smithy-http/Cargo.toml | 2 +- rust-runtime/smithy-http/src/event_stream.rs | 511 +---------------- .../smithy-http/src/event_stream/input.rs | 257 +++++++++ .../smithy-http/src/event_stream/output.rs | 533 ++++++++++++++++++ 11 files changed, 902 insertions(+), 541 deletions(-) create mode 100644 rust-runtime/smithy-http/src/event_stream/input.rs create mode 100644 rust-runtime/smithy-http/src/event_stream/output.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 91314d2e6..eefe50612 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ vNext (Month Day, Year) **New this week** - :bug: Fix an issue where `smithy-xml` may have generated invalid XML (smithy-rs#719) +- :bug: Fix error when receiving empty event stream messages (smithy-rs#736) +- :bug: Fix bug in event stream receiver that could cause the last events in the response stream to be lost (smithy-rs#736) v0.24 (September 24th, 2021) ============================ diff --git a/aws/SDK_CHANGELOG.md b/aws/SDK_CHANGELOG.md index 8cdd3ac80..651cbef86 100644 --- a/aws/SDK_CHANGELOG.md +++ b/aws/SDK_CHANGELOG.md @@ -9,6 +9,8 @@ vNext (Month Day, Year) **New This Week** - :tada: Add presigned request support and examples for S3 GetObject and PutObject (smithy-rs#731) +- :bug: Fix error when receiving `Cont` event from S3 SelectObjectContent (smithy-rs#736) +- :bug: Fix bug in event stream receiver that could cause the last events in the response stream to be lost when using S3 SelectObjectContent (smithy-rs#736) - Updated Transcribe code example to take an audio file as a command-line option and added readme. v0.0.19-alpha (September 24th, 2021) diff --git a/aws/sdk/integration-tests/s3/tests/select-object-content.rs b/aws/sdk/integration-tests/s3/tests/select-object-content.rs index 9d5ef9e8d..b75da912a 100644 --- a/aws/sdk/integration-tests/s3/tests/select-object-content.rs +++ b/aws/sdk/integration-tests/s3/tests/select-object-content.rs @@ -68,6 +68,7 @@ async fn test_success() { stats.bytes_scanned, stats.bytes_processed, stats.bytes_returned )) } + SelectObjectContentEventStream::End(_) => {} otherwise => panic!("unexpected message: {:?}", otherwise), } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt index fd4ab2fa9..44c9e07d6 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt @@ -193,6 +193,7 @@ data class CargoDependency( val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14")) val HyperWithStream: CargoDependency = Hyper.withFeature("stream") val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"), optional = true) + val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types") fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt index f683e4ca0..f057d0046 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt @@ -97,7 +97,7 @@ fun T.conditionalBlock( * Convenience wrapper that tells Intellij that the contents of this block are Rust */ fun T.rust( - @Language("Rust", prefix = "macro_rules! foo { () => {{ ", suffix = "}}}") contents: String, + @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") contents: String, vararg args: Any ) { this.write(contents.trim(), *args) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index fea6fc1c8..3226fadb1 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.model.traits.EventPayloadTrait import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate @@ -62,6 +63,7 @@ class EventStreamUnmarshallerGenerator( "Message" to RuntimeType("Message", smithyEventStream, "smithy_eventstream::frame"), "OpError" to operationErrorSymbol, "SmithyError" to RuntimeType("Error", CargoDependency.SmithyTypes(runtimeConfig), "smithy_types"), + "tracing" to CargoDependency.Tracing.asType(), "UnmarshalledMessage" to RuntimeType("UnmarshalledMessage", smithyEventStream, "smithy_eventstream::frame"), "UnmarshallMessage" to RuntimeType("UnmarshallMessage", smithyEventStream, "smithy_eventstream::frame"), ) @@ -149,41 +151,61 @@ class EventStreamUnmarshallerGenerator( private fun RustWriter.renderUnmarshallUnionMember(unionMember: MemberShape, unionStruct: StructureShape) { val unionMemberName = unionMember.memberName.toPascalCase() + val empty = unionStruct.members().isEmpty() val payloadOnly = unionStruct.members().none { it.hasTrait() || it.hasTrait() } - if (payloadOnly) { - withBlock("let parsed = ", ";") { - renderParseProtocolPayload(unionMember) + when { + // Don't attempt to parse the payload for an empty struct. The payload can be empty, or if the model was + // updated since the code was generated, it can have content that would not be understood. + empty -> { + rustTemplate( + "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(#{UnionStruct}::builder().build())))", + "Output" to unionSymbol, + "UnionStruct" to symbolProvider.toSymbol(unionStruct), + *codegenScope + ) } - rustTemplate( - "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(parsed)))", - "Output" to unionSymbol, - *codegenScope - ) - } else { - rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct)) - val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } - if (payloadMember != null) { - renderUnmarshallEventPayload(payloadMember) + payloadOnly -> { + withBlock("let parsed = ", ";") { + renderParseProtocolPayload(unionMember) + } + rustTemplate( + "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(parsed)))", + "Output" to unionSymbol, + *codegenScope + ) } - val headerMembers = unionStruct.members().filter { it.hasTrait() } - if (headerMembers.isNotEmpty()) { - rustBlock("for header in message.headers()") { - rustBlock("match header.name().as_str()") { - for (member in headerMembers) { - rustBlock("${member.memberName.dq()} => ") { - renderUnmarshallEventHeader(member) + else -> { + rust("let mut builder = #T::builder();", symbolProvider.toSymbol(unionStruct)) + val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } + if (payloadMember != null) { + renderUnmarshallEventPayload(payloadMember) + } + val headerMembers = unionStruct.members().filter { it.hasTrait() } + if (headerMembers.isNotEmpty()) { + rustBlock("for header in message.headers()") { + rustBlock("match header.name().as_str()") { + for (member in headerMembers) { + rustBlock("${member.memberName.dq()} => ") { + renderUnmarshallEventHeader(member) + } + } + rust("// Event stream protocol headers start with ':'") + rustBlock("name => if !name.starts_with(':')") { + rustTemplate( + "#{tracing}::trace!(\"Unrecognized event stream message header: {}\", name);", + *codegenScope + ) } } - rust("_ => {}") } } + rustTemplate( + "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(builder.build())))", + "Output" to unionSymbol, + *codegenScope + ) } - rustTemplate( - "Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(builder.build())))", - "Output" to unionSymbol, - *codegenScope - ) } } @@ -211,10 +233,11 @@ class EventStreamUnmarshallerGenerator( expectedContentType(target)?.also { contentType -> rustTemplate( """ - if response_headers.content_type.as_str() != ${contentType.dq()} { + let content_type = response_headers.content_type().unwrap_or_default(); + if content_type != ${contentType.dq()} { return Err(#{Error}::Unmarshalling(format!( "expected :content-type to be '$contentType', but was '{}'", - response_headers.content_type.as_str() + content_type ))) } """, diff --git a/rust-runtime/smithy-eventstream/src/smithy.rs b/rust-runtime/smithy-eventstream/src/smithy.rs index 72e83f1fd..60781eec2 100644 --- a/rust-runtime/smithy-eventstream/src/smithy.rs +++ b/rust-runtime/smithy-eventstream/src/smithy.rs @@ -32,12 +32,19 @@ expect_shape_fn!(fn expect_byte_array[ByteArray] -> Blob { bytes -> Blob::new(by expect_shape_fn!(fn expect_string[String] -> String { value -> value.as_str().into() }); expect_shape_fn!(fn expect_timestamp[Timestamp] -> Instant { value -> *value }); +#[derive(Debug)] pub struct ResponseHeaders<'a> { - pub content_type: &'a StrBytes, + pub content_type: Option<&'a StrBytes>, pub message_type: &'a StrBytes, pub smithy_type: &'a StrBytes, } +impl<'a> ResponseHeaders<'a> { + pub fn content_type(&self) -> Option<&str> { + self.content_type.map(|ct| ct.as_str()) + } +} + fn expect_header_str_value<'a>( header: Option<&'a Header>, name: &str, @@ -70,7 +77,9 @@ pub fn parse_response_headers(message: &Message) -> Result; - -/// Input type for Event Streams. -pub struct EventStreamInput { - input_stream: Pin> + Send>>, -} - -impl fmt::Debug for EventStreamInput { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "EventStreamInput(Box)") - } -} - -impl EventStreamInput { - #[doc(hidden)] - pub fn into_body_stream( - self, - marshaller: impl MarshallMessage + Send + Sync + 'static, - signer: impl SignMessage + Send + Sync + 'static, - ) -> MessageStreamAdapter { - MessageStreamAdapter::new(marshaller, signer, self.input_stream) - } -} - -impl From for EventStreamInput -where - S: Stream> + Send + 'static, -{ - fn from(stream: S) -> Self { - EventStreamInput { - input_stream: Box::pin(stream), - } - } -} - -/// Adapts a `Stream` to a signed `Stream` by using the provided -/// message marshaller and signer implementations. -/// -/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be -/// marshalled into an Event Stream frame, (e.g., if the message payload was too large). -#[pin_project] -pub struct MessageStreamAdapter { - marshaller: Box + Send + Sync>, - signer: Box, - #[pin] - stream: Pin> + Send>>, - end_signal_sent: bool, - _phantom: PhantomData, -} - -impl MessageStreamAdapter -where - E: StdError + Send + Sync + 'static, -{ - pub fn new( - marshaller: impl MarshallMessage + Send + Sync + 'static, - signer: impl SignMessage + Send + Sync + 'static, - stream: Pin> + Send>>, - ) -> Self { - MessageStreamAdapter { - marshaller: Box::new(marshaller), - signer: Box::new(signer), - stream, - end_signal_sent: false, - _phantom: Default::default(), - } - } -} - -impl Stream for MessageStreamAdapter -where - E: StdError + Send + Sync + 'static, -{ - type Item = Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - match this.stream.poll_next(cx) { - Poll::Ready(message_option) => { - if let Some(message_result) = message_option { - let message_result = - message_result.map_err(|err| SdkError::ConstructionFailure(err)); - let message = this - .marshaller - .marshall(message_result?) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; - let message = this - .signer - .sign(message) - .map_err(|err| SdkError::ConstructionFailure(err))?; - let mut buffer = Vec::new(); - message - .write_to(&mut buffer) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; - Poll::Ready(Some(Ok(Bytes::from(buffer)))) - } else if !*this.end_signal_sent { - *this.end_signal_sent = true; - let mut buffer = Vec::new(); - this.signer - .sign_empty() - .map_err(|err| SdkError::ConstructionFailure(err))? - .write_to(&mut buffer) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; - Poll::Ready(Some(Ok(Bytes::from(buffer)))) - } else { - Poll::Ready(None) - } - } - Poll::Pending => Poll::Pending, - } - } -} - -/// Receives Smithy-modeled messages out of an Event Stream. -#[derive(Debug)] -pub struct Receiver { - unmarshaller: Box + Send>, - decoder: MessageFrameDecoder, - buffer: SegmentedBuf, - body: SdkBody, - /// Event Stream has optional initial response frames an with `:message-type` of - /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an - /// initial response, then the message will be stored in `buffered_message` so that it can - /// be returned with the next call of `recv()`. - buffered_message: Option, - _phantom: PhantomData, -} - -impl Receiver { - /// Creates a new `Receiver` with the given message unmarshaller and SDK body. - pub fn new( - unmarshaller: impl UnmarshallMessage + Send + 'static, - body: SdkBody, - ) -> Self { - Receiver { - unmarshaller: Box::new(unmarshaller), - decoder: MessageFrameDecoder::new(), - buffer: SegmentedBuf::new(), - body, - buffered_message: None, - _phantom: Default::default(), - } - } - - fn unmarshall(&self, message: Message) -> Result, SdkError> { - match self.unmarshaller.unmarshall(&message) { - Ok(unmarshalled) => match unmarshalled { - UnmarshalledMessage::Event(event) => Ok(Some(event)), - UnmarshalledMessage::Error(err) => { - Err(SdkError::ServiceError { err, raw: message }) - } - }, - Err(err) => Err(SdkError::ResponseError { - raw: message, - err: Box::new(err), - }), - } - } - - async fn next_message(&mut self) -> Result, SdkError> { - let next_chunk = self - .body - .data() - .await - .transpose() - .map_err(|err| SdkError::DispatchFailure(err))?; - if let Some(chunk) = next_chunk { - // The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary - self.buffer.push(chunk); - if let DecodedFrame::Complete(message) = self - .decoder - .decode_frame(&mut self.buffer) - .map_err(|err| SdkError::DispatchFailure(Box::new(err)))? - { - return Ok(Some(message)); - } - } - Ok(None) - } - - /// Tries to receive the initial response message that has `:event-type` of `initial-response`. - /// If a different event type is received, then it is buffered and `Ok(None)` is returned. - #[doc(hidden)] - pub async fn try_recv_initial(&mut self) -> Result, SdkError> { - if let Some(message) = self.next_message().await? { - if let Some(event_type) = message - .headers() - .iter() - .find(|h| h.name().as_str() == ":event-type") - { - if event_type - .value() - .as_string() - .map(|s| s.as_str() == "initial-response") - .unwrap_or(false) - { - return Ok(Some(message)); - } - } else { - // Buffer the message so that it can be returned by the next call to `recv()` - self.buffered_message = Some(message); - } - } - Ok(None) - } - - /// Asynchronously tries to receive a message from the stream. If the stream has ended, - /// it returns an `Ok(None)`. If there is a transport layer error, it will return - /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned - /// messages. - pub async fn recv(&mut self) -> Result, SdkError> { - if let Some(buffered) = self.buffered_message.take() { - return self.unmarshall(buffered); - } - if let Some(message) = self.next_message().await? { - self.unmarshall(message) - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::{MarshallMessage, Receiver, UnmarshallMessage}; - use crate::body::SdkBody; - use crate::event_stream::{EventStreamInput, MessageStreamAdapter}; - use crate::result::SdkError; - use async_stream::stream; - use bytes::Bytes; - use futures_core::Stream; - use futures_util::stream::StreamExt; - use hyper::body::Body; - use smithy_eventstream::error::Error as EventStreamError; - use smithy_eventstream::frame::{ - Header, HeaderValue, Message, SignMessage, SignMessageError, UnmarshalledMessage, - }; - use std::error::Error as StdError; - use std::io::{Error as IOError, ErrorKind}; - - fn encode_initial_response() -> Bytes { - let mut buffer = Vec::new(); - Message::new(Bytes::new()) - .add_header(Header::new( - ":message-type", - HeaderValue::String("event".into()), - )) - .add_header(Header::new( - ":event-type", - HeaderValue::String("initial-response".into()), - )) - .write_to(&mut buffer) - .unwrap(); - buffer.into() - } - - fn encode_message(message: &str) -> Bytes { - let mut buffer = Vec::new(); - Message::new(Bytes::copy_from_slice(message.as_bytes())) - .write_to(&mut buffer) - .unwrap(); - buffer.into() - } - - #[derive(Debug)] - struct FakeError; - impl std::fmt::Display for FakeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "FakeError") - } - } - impl StdError for FakeError {} - - #[derive(Debug, Eq, PartialEq)] - struct TestMessage(String); +mod input; +mod output; - #[derive(Debug)] - struct Marshaller; - impl MarshallMessage for Marshaller { - type Input = TestMessage; - - fn marshall(&self, input: Self::Input) -> Result { - Ok(Message::new(input.0.as_bytes().to_vec())) - } - } - - #[derive(Debug)] - struct Unmarshaller; - impl UnmarshallMessage for Unmarshaller { - type Output = TestMessage; - type Error = EventStreamError; - - fn unmarshall( - &self, - message: &Message, - ) -> Result, EventStreamError> { - Ok(UnmarshalledMessage::Event(TestMessage( - std::str::from_utf8(&message.payload()[..]).unwrap().into(), - ))) - } - } - - #[tokio::test] - async fn receive_success() { - let chunks: Vec> = - vec![Ok(encode_message("one")), Ok(encode_message("two"))]; - let chunk_stream = futures_util::stream::iter(chunks); - let body = SdkBody::from(Body::wrap_stream(chunk_stream)); - let mut receiver = Receiver::::new(Unmarshaller, body); - assert_eq!( - TestMessage("one".into()), - receiver.recv().await.unwrap().unwrap() - ); - assert_eq!( - TestMessage("two".into()), - receiver.recv().await.unwrap().unwrap() - ); - } - - #[tokio::test] - async fn receive_network_failure() { - let chunks: Vec> = vec![ - Ok(encode_message("one")), - Err(IOError::new(ErrorKind::ConnectionReset, FakeError)), - ]; - let chunk_stream = futures_util::stream::iter(chunks); - let body = SdkBody::from(Body::wrap_stream(chunk_stream)); - let mut receiver = Receiver::::new(Unmarshaller, body); - assert_eq!( - TestMessage("one".into()), - receiver.recv().await.unwrap().unwrap() - ); - assert!(matches!( - receiver.recv().await, - Err(SdkError::DispatchFailure(_)) - )); - } - - #[tokio::test] - async fn receive_message_parse_failure() { - let chunks: Vec> = vec![ - Ok(encode_message("one")), - // A zero length message will be invalid. We need to provide a minimum of 12 bytes - // for the MessageFrameDecoder to actually start parsing it. - Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])), - ]; - let chunk_stream = futures_util::stream::iter(chunks); - let body = SdkBody::from(Body::wrap_stream(chunk_stream)); - let mut receiver = Receiver::::new(Unmarshaller, body); - assert_eq!( - TestMessage("one".into()), - receiver.recv().await.unwrap().unwrap() - ); - assert!(matches!( - receiver.recv().await, - Err(SdkError::DispatchFailure(_)) - )); - } - - #[tokio::test] - async fn receive_initial_response() { - let chunks: Vec> = - vec![Ok(encode_initial_response()), Ok(encode_message("one"))]; - let chunk_stream = futures_util::stream::iter(chunks); - let body = SdkBody::from(Body::wrap_stream(chunk_stream)); - let mut receiver = Receiver::::new(Unmarshaller, body); - assert!(receiver.try_recv_initial().await.unwrap().is_some()); - assert_eq!( - TestMessage("one".into()), - receiver.recv().await.unwrap().unwrap() - ); - } - - #[tokio::test] - async fn receive_no_initial_response() { - let chunks: Vec> = - vec![Ok(encode_message("one")), Ok(encode_message("two"))]; - let chunk_stream = futures_util::stream::iter(chunks); - let body = SdkBody::from(Body::wrap_stream(chunk_stream)); - let mut receiver = Receiver::::new(Unmarshaller, body); - assert!(receiver.try_recv_initial().await.unwrap().is_none()); - assert_eq!( - TestMessage("one".into()), - receiver.recv().await.unwrap().unwrap() - ); - assert_eq!( - TestMessage("two".into()), - receiver.recv().await.unwrap().unwrap() - ); - } - - fn assert_send() {} - - #[tokio::test] - async fn receiver_is_send() { - assert_send::>(); - } - - #[derive(Debug)] - struct TestServiceError; - impl std::fmt::Display for TestServiceError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "TestServiceError") - } - } - impl StdError for TestServiceError {} - - #[derive(Debug)] - struct TestSigner; - impl SignMessage for TestSigner { - fn sign(&mut self, message: Message) -> Result { - let mut buffer = Vec::new(); - message.write_to(&mut buffer).unwrap(); - Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true)))) - } - - fn sign_empty(&mut self) -> Result { - Ok(Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))) - } - } - - fn check_compatible_with_hyper_wrap_stream(stream: S) -> S - where - S: Stream> + Send + 'static, - O: Into + 'static, - E: Into> + 'static, - { - stream - } - - #[tokio::test] - async fn message_stream_adapter_success() { - let stream = stream! { - yield Ok(TestMessage("test".into())); - }; - let mut adapter = - check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( - Marshaller, TestSigner, Box::pin(stream) - )); - - let mut sent_bytes = adapter.next().await.unwrap().unwrap(); - let sent = Message::read_from(&mut sent_bytes).unwrap(); - assert_eq!("signed", sent.headers()[0].name().as_str()); - assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value()); - let inner = Message::read_from(&mut (&sent.payload()[..])).unwrap(); - assert_eq!(&b"test"[..], &inner.payload()[..]); - - let mut end_signal_bytes = adapter.next().await.unwrap().unwrap(); - let end_signal = Message::read_from(&mut end_signal_bytes).unwrap(); - assert_eq!("signed", end_signal.headers()[0].name().as_str()); - assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value()); - assert_eq!(0, end_signal.payload().len()); - } - - #[tokio::test] - async fn message_stream_adapter_construction_failure() { - let stream = stream! { - yield Err(EventStreamError::InvalidMessageLength.into()); - }; - let mut adapter = - check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( - Marshaller, TestSigner, Box::pin(stream) - )); +pub type BoxError = Box; - let result = adapter.next().await.unwrap(); - assert!(result.is_err()); - assert!(matches!( - result.err().unwrap(), - SdkError::ConstructionFailure(_) - )); - } +#[doc(inline)] +pub use input::{EventStreamInput, MessageStreamAdapter}; - // Verify the developer experience for this compiles - #[allow(unused)] - fn event_stream_input_ergonomics() { - fn check(input: impl Into>) { - let _: EventStreamInput = input.into(); - } - check(stream! { - yield Ok(TestMessage("test".into())); - }); - check(stream! { - yield Err(EventStreamError::InvalidMessageLength.into()); - }); - } -} +#[doc(inline)] +pub use output::{Error, RawMessage, Receiver}; diff --git a/rust-runtime/smithy-http/src/event_stream/input.rs b/rust-runtime/smithy-http/src/event_stream/input.rs new file mode 100644 index 000000000..810a269df --- /dev/null +++ b/rust-runtime/smithy-http/src/event_stream/input.rs @@ -0,0 +1,257 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use super::BoxError; +use crate::result::SdkError; +use bytes::Bytes; +use futures_core::Stream; +use pin_project::pin_project; +use smithy_eventstream::frame::{MarshallMessage, SignMessage}; +use std::error::Error as StdError; +use std::fmt; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Input type for Event Streams. +pub struct EventStreamInput { + input_stream: Pin> + Send>>, +} + +impl fmt::Debug for EventStreamInput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "EventStreamInput(Box)") + } +} + +impl EventStreamInput { + #[doc(hidden)] + pub fn into_body_stream( + self, + marshaller: impl MarshallMessage + Send + Sync + 'static, + signer: impl SignMessage + Send + Sync + 'static, + ) -> MessageStreamAdapter { + MessageStreamAdapter::new(marshaller, signer, self.input_stream) + } +} + +impl From for EventStreamInput +where + S: Stream> + Send + 'static, +{ + fn from(stream: S) -> Self { + EventStreamInput { + input_stream: Box::pin(stream), + } + } +} + +/// Adapts a `Stream` to a signed `Stream` by using the provided +/// message marshaller and signer implementations. +/// +/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be +/// marshalled into an Event Stream frame, (e.g., if the message payload was too large). +#[pin_project] +pub struct MessageStreamAdapter { + marshaller: Box + Send + Sync>, + signer: Box, + #[pin] + stream: Pin> + Send>>, + end_signal_sent: bool, + _phantom: PhantomData, +} + +impl MessageStreamAdapter +where + E: StdError + Send + Sync + 'static, +{ + pub fn new( + marshaller: impl MarshallMessage + Send + Sync + 'static, + signer: impl SignMessage + Send + Sync + 'static, + stream: Pin> + Send>>, + ) -> Self { + MessageStreamAdapter { + marshaller: Box::new(marshaller), + signer: Box::new(signer), + stream, + end_signal_sent: false, + _phantom: Default::default(), + } + } +} + +impl Stream for MessageStreamAdapter +where + E: StdError + Send + Sync + 'static, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match this.stream.poll_next(cx) { + Poll::Ready(message_option) => { + if let Some(message_result) = message_option { + let message_result = + message_result.map_err(|err| SdkError::ConstructionFailure(err)); + let message = this + .marshaller + .marshall(message_result?) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + let message = this + .signer + .sign(message) + .map_err(|err| SdkError::ConstructionFailure(err))?; + let mut buffer = Vec::new(); + message + .write_to(&mut buffer) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + Poll::Ready(Some(Ok(Bytes::from(buffer)))) + } else if !*this.end_signal_sent { + *this.end_signal_sent = true; + let mut buffer = Vec::new(); + this.signer + .sign_empty() + .map_err(|err| SdkError::ConstructionFailure(err))? + .write_to(&mut buffer) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + Poll::Ready(Some(Ok(Bytes::from(buffer)))) + } else { + Poll::Ready(None) + } + } + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use super::MarshallMessage; + use crate::event_stream::{EventStreamInput, MessageStreamAdapter}; + use crate::result::SdkError; + use async_stream::stream; + use bytes::Bytes; + use futures_core::Stream; + use futures_util::stream::StreamExt; + use smithy_eventstream::error::Error as EventStreamError; + use smithy_eventstream::frame::{Header, HeaderValue, Message, SignMessage, SignMessageError}; + use std::error::Error as StdError; + + #[derive(Debug)] + struct FakeError; + impl std::fmt::Display for FakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FakeError") + } + } + impl StdError for FakeError {} + + #[derive(Debug, Eq, PartialEq)] + struct TestMessage(String); + + #[derive(Debug)] + struct Marshaller; + impl MarshallMessage for Marshaller { + type Input = TestMessage; + + fn marshall(&self, input: Self::Input) -> Result { + Ok(Message::new(input.0.as_bytes().to_vec())) + } + } + + #[derive(Debug)] + struct TestServiceError; + impl std::fmt::Display for TestServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestServiceError") + } + } + impl StdError for TestServiceError {} + + #[derive(Debug)] + struct TestSigner; + impl SignMessage for TestSigner { + fn sign(&mut self, message: Message) -> Result { + let mut buffer = Vec::new(); + message.write_to(&mut buffer).unwrap(); + Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true)))) + } + + fn sign_empty(&mut self) -> Result { + Ok(Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))) + } + } + + fn check_compatible_with_hyper_wrap_stream(stream: S) -> S + where + S: Stream> + Send + 'static, + O: Into + 'static, + E: Into> + 'static, + { + stream + } + + #[tokio::test] + async fn message_stream_adapter_success() { + let stream = stream! { + yield Ok(TestMessage("test".into())); + }; + let mut adapter = + check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< + TestMessage, + TestServiceError, + >::new( + Marshaller, TestSigner, Box::pin(stream) + )); + + let mut sent_bytes = adapter.next().await.unwrap().unwrap(); + let sent = Message::read_from(&mut sent_bytes).unwrap(); + assert_eq!("signed", sent.headers()[0].name().as_str()); + assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value()); + let inner = Message::read_from(&mut (&sent.payload()[..])).unwrap(); + assert_eq!(&b"test"[..], &inner.payload()[..]); + + let mut end_signal_bytes = adapter.next().await.unwrap().unwrap(); + let end_signal = Message::read_from(&mut end_signal_bytes).unwrap(); + assert_eq!("signed", end_signal.headers()[0].name().as_str()); + assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value()); + assert_eq!(0, end_signal.payload().len()); + } + + #[tokio::test] + async fn message_stream_adapter_construction_failure() { + let stream = stream! { + yield Err(EventStreamError::InvalidMessageLength.into()); + }; + let mut adapter = + check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< + TestMessage, + TestServiceError, + >::new( + Marshaller, TestSigner, Box::pin(stream) + )); + + let result = adapter.next().await.unwrap(); + assert!(result.is_err()); + assert!(matches!( + result.err().unwrap(), + SdkError::ConstructionFailure(_) + )); + } + + // Verify the developer experience for this compiles + #[allow(unused)] + fn event_stream_input_ergonomics() { + fn check(input: impl Into>) { + let _: EventStreamInput = input.into(); + } + check(stream! { + yield Ok(TestMessage("test".into())); + }); + check(stream! { + yield Err(EventStreamError::InvalidMessageLength.into()); + }); + } +} diff --git a/rust-runtime/smithy-http/src/event_stream/output.rs b/rust-runtime/smithy-http/src/event_stream/output.rs new file mode 100644 index 000000000..45223aaa2 --- /dev/null +++ b/rust-runtime/smithy-http/src/event_stream/output.rs @@ -0,0 +1,533 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use crate::body::SdkBody; +use crate::result::SdkError; +use bytes::Buf; +use bytes::Bytes; +use bytes_utils::SegmentedBuf; +use hyper::body::HttpBody; +use smithy_eventstream::frame::{ + DecodedFrame, Message, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage, +}; +use std::error::Error as StdError; +use std::fmt; +use std::marker::PhantomData; +use std::mem; + +/// Wrapper around SegmentedBuf that tracks the state of the stream. +#[derive(Debug)] +enum RecvBuf { + /// Nothing has been buffered yet. + Empty, + /// Some data has been buffered. + /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary + Partial(SegmentedBuf), + /// The end of the stream has been reached, but there may still be some buffered data. + EosPartial(SegmentedBuf), +} + +impl RecvBuf { + /// Returns true if there's more buffered data. + fn has_data(&self) -> bool { + match self { + RecvBuf::Empty => false, + RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0, + } + } + + /// Returns true if the stream has ended. + fn is_eos(&self) -> bool { + matches!(self, RecvBuf::EosPartial(_)) + } + + /// Returns a mutable reference to the underlying buffered data. + fn buffered(&mut self) -> &mut SegmentedBuf { + match self { + RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"), + RecvBuf::Partial(segmented) => segmented, + RecvBuf::EosPartial(segmented) => segmented, + } + } + + /// Returns a new `RecvBuf` with additional data buffered. This will only allocate + /// if the `RecvBuf` was previously empty. + fn with_partial(self, partial: Bytes) -> Self { + match self { + RecvBuf::Empty => { + let mut segmented = SegmentedBuf::new(); + segmented.push(partial); + RecvBuf::Partial(segmented) + } + RecvBuf::Partial(mut segmented) => { + segmented.push(partial); + RecvBuf::Partial(segmented) + } + RecvBuf::EosPartial(_) => { + panic!("cannot buffer more data after the stream has ended; this is a bug") + } + } + } + + /// Returns a `RecvBuf` that has reached end of stream. + fn ended(self) -> Self { + match self { + RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()), + RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented), + RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"), + } + } +} + +/// Raw message from a [`Receiver`] when a [`SdkError::ResponseError`] is returned. +#[derive(Debug)] +#[non_exhaustive] +pub enum RawMessage { + /// Message was decoded into a valid frame, but failed to unmarshall into a modeled type. + Decoded(Message), + /// Message failed to be decoded into a valid frame. The raw bytes may not be available in the + /// case where decoding consumed the buffer. + Invalid(Option), +} + +impl From<&mut SegmentedBuf> for RawMessage { + fn from(buf: &mut SegmentedBuf) -> Self { + Self::Invalid(Some(buf.copy_to_bytes(buf.remaining()))) + } +} + +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// The stream ended before a complete message frame was received. + #[non_exhaustive] + UnexpectedEndOfStream, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnexpectedEndOfStream => write!(f, "unexpected end of stream"), + } + } +} + +impl StdError for Error {} + +/// Receives Smithy-modeled messages out of an Event Stream. +#[derive(Debug)] +pub struct Receiver { + unmarshaller: Box + Send>, + decoder: MessageFrameDecoder, + buffer: RecvBuf, + body: SdkBody, + /// Event Stream has optional initial response frames an with `:message-type` of + /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an + /// initial response, then the message will be stored in `buffered_message` so that it can + /// be returned with the next call of `recv()`. + buffered_message: Option, + _phantom: PhantomData, +} + +impl Receiver { + /// Creates a new `Receiver` with the given message unmarshaller and SDK body. + pub fn new( + unmarshaller: impl UnmarshallMessage + Send + 'static, + body: SdkBody, + ) -> Self { + Receiver { + unmarshaller: Box::new(unmarshaller), + decoder: MessageFrameDecoder::new(), + buffer: RecvBuf::Empty, + body, + buffered_message: None, + _phantom: Default::default(), + } + } + + fn unmarshall(&self, message: Message) -> Result, SdkError> { + match self.unmarshaller.unmarshall(&message) { + Ok(unmarshalled) => match unmarshalled { + UnmarshalledMessage::Event(event) => Ok(Some(event)), + UnmarshalledMessage::Error(err) => Err(SdkError::ServiceError { + err, + raw: RawMessage::Decoded(message), + }), + }, + Err(err) => Err(SdkError::ResponseError { + err: Box::new(err), + raw: RawMessage::Decoded(message), + }), + } + } + + async fn buffer_next_chunk(&mut self) -> Result<(), SdkError> { + if !self.buffer.is_eos() { + let next_chunk = self + .body + .data() + .await + .transpose() + .map_err(|err| SdkError::DispatchFailure(err))?; + let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty); + if let Some(chunk) = next_chunk { + self.buffer = buffer.with_partial(chunk); + } else { + self.buffer = buffer.ended(); + } + } + Ok(()) + } + + async fn next_message(&mut self) -> Result, SdkError> { + while !self.buffer.is_eos() { + if self.buffer.has_data() { + if let DecodedFrame::Complete(message) = self + .decoder + .decode_frame(self.buffer.buffered()) + .map_err(|err| SdkError::ResponseError { + err: Box::new(err), + raw: RawMessage::Invalid(None), // the buffer has been consumed + })? + { + return Ok(Some(message)); + } + } + + self.buffer_next_chunk().await?; + } + if self.buffer.has_data() { + return Err(SdkError::ResponseError { + err: Error::UnexpectedEndOfStream.into(), + raw: self.buffer.buffered().into(), + }); + } + Ok(None) + } + + /// Tries to receive the initial response message that has `:event-type` of `initial-response`. + /// If a different event type is received, then it is buffered and `Ok(None)` is returned. + #[doc(hidden)] + pub async fn try_recv_initial(&mut self) -> Result, SdkError> { + if let Some(message) = self.next_message().await? { + if let Some(event_type) = message + .headers() + .iter() + .find(|h| h.name().as_str() == ":event-type") + { + if event_type + .value() + .as_string() + .map(|s| s.as_str() == "initial-response") + .unwrap_or(false) + { + return Ok(Some(message)); + } + } else { + // Buffer the message so that it can be returned by the next call to `recv()` + self.buffered_message = Some(message); + } + } + Ok(None) + } + + /// Asynchronously tries to receive a message from the stream. If the stream has ended, + /// it returns an `Ok(None)`. If there is a transport layer error, it will return + /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned + /// messages. + pub async fn recv(&mut self) -> Result, SdkError> { + if let Some(buffered) = self.buffered_message.take() { + return self.unmarshall(buffered); + } + if let Some(message) = self.next_message().await? { + self.unmarshall(message) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::{Receiver, UnmarshallMessage}; + use crate::body::SdkBody; + use crate::result::SdkError; + use bytes::Bytes; + use hyper::body::Body; + use smithy_eventstream::error::Error as EventStreamError; + use smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshalledMessage}; + use std::error::Error as StdError; + use std::io::{Error as IOError, ErrorKind}; + + fn encode_initial_response() -> Bytes { + let mut buffer = Vec::new(); + Message::new(Bytes::new()) + .add_header(Header::new( + ":message-type", + HeaderValue::String("event".into()), + )) + .add_header(Header::new( + ":event-type", + HeaderValue::String("initial-response".into()), + )) + .write_to(&mut buffer) + .unwrap(); + buffer.into() + } + + fn encode_message(message: &str) -> Bytes { + let mut buffer = Vec::new(); + Message::new(Bytes::copy_from_slice(message.as_bytes())) + .write_to(&mut buffer) + .unwrap(); + buffer.into() + } + + #[derive(Debug)] + struct FakeError; + impl std::fmt::Display for FakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FakeError") + } + } + impl StdError for FakeError {} + + #[derive(Debug, Eq, PartialEq)] + struct TestMessage(String); + + #[derive(Debug)] + struct Unmarshaller; + impl UnmarshallMessage for Unmarshaller { + type Output = TestMessage; + type Error = EventStreamError; + + fn unmarshall( + &self, + message: &Message, + ) -> Result, EventStreamError> { + Ok(UnmarshalledMessage::Event(TestMessage( + std::str::from_utf8(&message.payload()[..]).unwrap().into(), + ))) + } + } + + #[tokio::test] + async fn receive_success() { + let chunks: Vec> = + vec![Ok(encode_message("one")), Ok(encode_message("two"))]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!(None, receiver.recv().await.unwrap()); + } + + #[tokio::test] + async fn receive_last_chunk_empty() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + Ok(encode_message("two")), + Ok(Bytes::from_static(&[])), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!(None, receiver.recv().await.unwrap()); + } + + #[tokio::test] + async fn receive_last_chunk_not_full_message() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + Ok(encode_message("two")), + Ok(encode_message("three").split_to(10)), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert!(matches!( + receiver.recv().await, + Err(SdkError::ResponseError { .. }), + )); + } + + #[tokio::test] + async fn receive_last_chunk_has_multiple_messages() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + Ok(encode_message("two")), + Ok(Bytes::from( + [encode_message("three"), encode_message("four")].concat(), + )), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("three".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("four".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!(None, receiver.recv().await.unwrap()); + } + + proptest::proptest! { + #[test] + fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) { + let combined = Bytes::from([ + encode_message("one"), + encode_message("two"), + encode_message("three"), + encode_message("four"), + encode_message("five"), + encode_message("six"), + encode_message("seven"), + encode_message("eight"), + ].concat()); + + let midpoint = combined.len() / 2; + let (start, boundary1, boundary2, end) = ( + 0, + b1 % midpoint, + midpoint + b2 % midpoint, + combined.len() + ); + println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + let chunks: Vec> = vec![ + Ok(Bytes::copy_from_slice(&combined[start..boundary1])), + Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])), + Ok(Bytes::copy_from_slice(&combined[boundary2..end])), + ]; + + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] { + assert_eq!( + TestMessage((*payload).into()), + receiver.recv().await.unwrap().unwrap() + ); + } + assert_eq!(None, receiver.recv().await.unwrap()); + }); + } + } + + #[tokio::test] + async fn receive_network_failure() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + Err(IOError::new(ErrorKind::ConnectionReset, FakeError)), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert!(matches!( + receiver.recv().await, + Err(SdkError::DispatchFailure(_)) + )); + } + + #[tokio::test] + async fn receive_message_parse_failure() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + // A zero length message will be invalid. We need to provide a minimum of 12 bytes + // for the MessageFrameDecoder to actually start parsing it. + Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert!(matches!( + receiver.recv().await, + Err(SdkError::ResponseError { .. }) + )); + } + + #[tokio::test] + async fn receive_initial_response() { + let chunks: Vec> = + vec![Ok(encode_initial_response()), Ok(encode_message("one"))]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert!(receiver.try_recv_initial().await.unwrap().is_some()); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + } + + #[tokio::test] + async fn receive_no_initial_response() { + let chunks: Vec> = + vec![Ok(encode_message("one")), Ok(encode_message("two"))]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = Receiver::::new(Unmarshaller, body); + assert!(receiver.try_recv_initial().await.unwrap().is_none()); + assert_eq!( + TestMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + TestMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + } + + fn assert_send() {} + + #[tokio::test] + async fn receiver_is_send() { + assert_send::>(); + } +} -- GitLab