diff --git a/CHANGELOG.md b/CHANGELOG.md index dea9fd0e1631b972c06e7ec35eaf318ea46d4625..234bbe8bb1bd0abaa2c132c64a90313bdf716d70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ vNext (Month Day Year) - (When complete) Add profile file provider for region (#594, #xyz) - Add AssumeRoleProvider parser implementation. (#632) - The closure passed to `async_provide_credentials_fn` can now borrow values (#637) +- Add `Sender`/`Receiver` implementations for Event Stream (#639) v0.19 (August 3rd, 2021) ------------------------ diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 7ee0b04a8b57513e74498c808bdbcfc7354bb2f6..0ba5c6ed5e5aa6842d0b2992fc0af0aa8fec32a7 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -21,24 +21,25 @@ val smithyVersion: String by project val sdkOutputDir = buildDir.resolve("aws-sdk") val runtimeModules = listOf( + "protocol-test-helpers", "smithy-async", - "smithy-types", - "smithy-json", - "smithy-query", - "smithy-xml", + "smithy-client", + "smithy-eventstream", "smithy-http", "smithy-http-tower", - "smithy-client", - "protocol-test-helpers" + "smithy-json", + "smithy-query", + "smithy-types", + "smithy-xml" ) val awsModules = listOf( "aws-auth", + "aws-auth-providers", "aws-endpoint", - "aws-types", + "aws-http", "aws-hyper", "aws-sig-auth", - "aws-http", - "aws-auth-providers" + "aws-types" ) buildscript { diff --git a/rust-runtime/smithy-eventstream/src/frame.rs b/rust-runtime/smithy-eventstream/src/frame.rs index 1f9d04dc4d6568b63445f706e36fb5a42138588b..5d02ea5c496bfe17ff212bb46070c20381f6a590 100644 --- a/rust-runtime/smithy-eventstream/src/frame.rs +++ b/rust-runtime/smithy-eventstream/src/frame.rs @@ -11,6 +11,7 @@ use crate::error::Error; use crate::str_bytes::StrBytes; use bytes::{Buf, BufMut, Bytes}; use std::convert::{TryFrom, TryInto}; +use std::error::Error as StdError; use std::mem::size_of; const PRELUDE_LENGTH_BYTES: u32 = 3 * size_of::() as u32; @@ -19,6 +20,29 @@ const MESSAGE_CRC_LENGTH_BYTES: u32 = size_of::() as u32; const MAX_HEADER_NAME_LEN: usize = 255; const MIN_HEADER_LEN: usize = 2; +pub type SignMessageError = Box; + +/// Signs an Event Stream message. +pub trait SignMessage { + fn sign(&mut self, message: Message) -> Result; +} + +/// Converts a Smithy modeled Event Stream type into a [`Message`](Message). +pub trait MarshallMessage { + /// Smithy modeled input type to convert from. + type Input; + + fn marshall(&self, input: Self::Input) -> Result; +} + +/// Converts an Event Stream [`Message`](Message) into a Smithy modeled type. +pub trait UnmarshallMessage { + /// Smithy modeled type to convert into. + type Output; + + fn unmarshall(&self, message: Message) -> Result; +} + mod value { use crate::error::Error; use crate::frame::checked; diff --git a/rust-runtime/smithy-http/Cargo.toml b/rust-runtime/smithy-http/Cargo.toml index 1297f830c577a6e62465d9ce4a506e3e34ed19f7..752bc022a510775fd04e58de2d6f3f7b7586da34 100644 --- a/rust-runtime/smithy-http/Cargo.toml +++ b/rust-runtime/smithy-http/Cargo.toml @@ -7,11 +7,14 @@ license = "Apache-2.0" [features] bytestream-util = ["tokio/fs", "tokio-util/io"] -default = ["bytestream-util"] +event-stream = ["smithy-eventstream"] +default = ["bytestream-util", "event-stream"] [dependencies] smithy-types = { path = "../smithy-types" } +smithy-eventstream = { path = "../smithy-eventstream", optional = true } bytes = "1" +bytes-utils = "0.1" http-body = "0.4.0" http = "0.2.3" thiserror = "1" @@ -23,12 +26,14 @@ tracing = "0.1" hyper = "0.14.5" # ByteStream internals -bytes-utils = "0.1.1" futures-core = "0.3.14" tokio = { version = "1.6", optional = true } tokio-util = { version = "0.6", optional = true} [dev-dependencies] +async-stream = "0.3" +futures-util = "0.3" +hyper = { version = "0.14.5", features = ["stream"] } proptest = "1" tokio = {version = "1.6", features = ["macros", "rt", "fs", "io-util"]} tokio-stream = "0.1.5" diff --git a/rust-runtime/smithy-http/src/event_stream.rs b/rust-runtime/smithy-http/src/event_stream.rs new file mode 100644 index 0000000000000000000000000000000000000000..137534badbe1e9df1c18e49a02c87d42198e4c06 --- /dev/null +++ b/rust-runtime/smithy-http/src/event_stream.rs @@ -0,0 +1,302 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +//! Provides Sender/Receiver implementations for Event Stream codegen. + +use crate::body::SdkBody; +use crate::result::SdkError; +use bytes::Bytes; +use bytes_utils::SegmentedBuf; +use futures_core::Stream; +use hyper::body::HttpBody; +use pin_project::pin_project; +use smithy_eventstream::frame::{ + DecodedFrame, MarshallMessage, MessageFrameDecoder, SignMessage, UnmarshallMessage, +}; +use std::error::Error as StdError; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Adapts a `Stream` to a signed `Stream` by using the provided +/// message marshaller and signer implementations. +/// +/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be +/// marshalled into an Event Stream frame, (e.g., if the message payload was too large). +#[pin_project] +pub struct MessageStreamAdapter { + marshaller: Box + Send + Sync>, + signer: Box, + #[pin] + stream: Pin> + Send + Sync>>, +} + +impl MessageStreamAdapter { + pub fn new( + marshaller: impl MarshallMessage + Send + Sync + 'static, + signer: impl SignMessage + Send + Sync + 'static, + stream: impl Stream> + Send + Sync + 'static, + ) -> Self { + MessageStreamAdapter { + marshaller: Box::new(marshaller), + signer: Box::new(signer), + stream: Box::pin(stream), + } + } +} + +impl Stream for MessageStreamAdapter { + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match this.stream.poll_next(cx) { + Poll::Ready(message_option) => { + if let Some(message_result) = message_option { + let message_result = + message_result.map_err(|err| SdkError::ConstructionFailure(Box::new(err))); + let message = this + .marshaller + .marshall(message_result?) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + let message = this + .signer + .sign(message) + .map_err(|err| SdkError::ConstructionFailure(err))?; + let mut buffer = Vec::new(); + message + .write_to(&mut buffer) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + Poll::Ready(Some(Ok(Bytes::from(buffer)))) + } else { + Poll::Ready(None) + } + } + Poll::Pending => Poll::Pending, + } + } +} + +/// Receives Smithy-modeled messages out of an Event Stream. +pub struct Receiver { + unmarshaller: Box>, + decoder: MessageFrameDecoder, + buffer: SegmentedBuf, + body: SdkBody, + _phantom: PhantomData, +} + +impl Receiver { + /// Creates a new `Receiver` with the given message unmarshaller and SDK body. + pub fn new(unmarshaller: impl UnmarshallMessage + 'static, body: SdkBody) -> Self { + Receiver { + unmarshaller: Box::new(unmarshaller), + decoder: MessageFrameDecoder::new(), + buffer: SegmentedBuf::new(), + body, + _phantom: Default::default(), + } + } + + /// Asynchronously tries to receive a message from the stream. If the stream has ended, + /// it returns an `Ok(None)`. If there is a transport layer error, it will return + /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned + /// messages. + pub async fn recv(&mut self) -> Result, SdkError> { + let next_chunk = self + .body + .data() + .await + .transpose() + .map_err(|err| SdkError::DispatchFailure(err))?; + if let Some(chunk) = next_chunk { + // The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary + self.buffer.push(chunk); + if let DecodedFrame::Complete(message) = self + .decoder + .decode_frame(&mut self.buffer) + .map_err(|err| SdkError::DispatchFailure(Box::new(err)))? + { + return Ok(Some( + self.unmarshaller + .unmarshall(message) + .map_err(|err| SdkError::DispatchFailure(Box::new(err)))?, + )); + } + } + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::{MarshallMessage, Receiver, UnmarshallMessage}; + use crate::body::SdkBody; + use crate::event_stream::MessageStreamAdapter; + use crate::result::SdkError; + use async_stream::stream; + use bytes::Bytes; + use futures_core::Stream; + use futures_util::stream::StreamExt; + use hyper::body::Body; + use smithy_eventstream::error::Error as EventStreamError; + use smithy_eventstream::frame::{Header, HeaderValue, Message, SignMessage, SignMessageError}; + use std::error::Error as StdError; + use std::io::{Error as IOError, ErrorKind}; + + fn encode_message(message: &str) -> Bytes { + let mut buffer = Vec::new(); + Message::new(Bytes::copy_from_slice(message.as_bytes())) + .write_to(&mut buffer) + .unwrap(); + buffer.into() + } + + #[derive(Debug)] + struct FakeError; + impl std::fmt::Display for FakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FakeError") + } + } + impl StdError for FakeError {} + + #[derive(Debug, Eq, PartialEq)] + struct UnmarshalledMessage(String); + + struct Marshaller; + impl MarshallMessage for Marshaller { + type Input = UnmarshalledMessage; + + fn marshall(&self, input: Self::Input) -> Result { + Ok(Message::new(input.0.as_bytes().to_vec())) + } + } + + struct Unmarshaller; + impl UnmarshallMessage for Unmarshaller { + type Output = UnmarshalledMessage; + + fn unmarshall(&self, message: Message) -> Result { + Ok(UnmarshalledMessage( + std::str::from_utf8(&message.payload()[..]).unwrap().into(), + )) + } + } + + #[tokio::test] + async fn receive_success() { + let chunks: Vec> = + vec![Ok(encode_message("one")), Ok(encode_message("two"))]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = + Receiver::::new(Unmarshaller, body); + assert_eq!( + UnmarshalledMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert_eq!( + UnmarshalledMessage("two".into()), + receiver.recv().await.unwrap().unwrap() + ); + } + + #[tokio::test] + async fn receive_network_failure() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + Err(IOError::new(ErrorKind::ConnectionReset, FakeError)), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = + Receiver::::new(Unmarshaller, body); + assert_eq!( + UnmarshalledMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert!(matches!( + receiver.recv().await, + Err(SdkError::DispatchFailure(_)) + )); + } + + #[tokio::test] + async fn receive_message_parse_failure() { + let chunks: Vec> = vec![ + Ok(encode_message("one")), + // A zero length message will be invalid. We need to provide a minimum of 12 bytes + // for the MessageFrameDecoder to actually start parsing it. + Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])), + ]; + let chunk_stream = futures_util::stream::iter(chunks); + let body = SdkBody::from(Body::wrap_stream(chunk_stream)); + let mut receiver = + Receiver::::new(Unmarshaller, body); + assert_eq!( + UnmarshalledMessage("one".into()), + receiver.recv().await.unwrap().unwrap() + ); + assert!(matches!( + receiver.recv().await, + Err(SdkError::DispatchFailure(_)) + )); + } + + struct TestSigner; + impl SignMessage for TestSigner { + fn sign(&mut self, message: Message) -> Result { + let mut buffer = Vec::new(); + message.write_to(&mut buffer).unwrap(); + Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true)))) + } + } + + fn check_compatible_with_hyper_wrap_stream(stream: S) -> S + where + S: Stream> + Send + 'static, + O: Into + 'static, + E: Into> + 'static, + { + stream + } + + #[tokio::test] + async fn message_stream_adapter_success() { + let stream = stream! { + yield Ok(UnmarshalledMessage("test".into())); + }; + let mut adapter = + check_compatible_with_hyper_wrap_stream( + MessageStreamAdapter::<_, EventStreamError>::new(Marshaller, TestSigner, stream), + ); + + let mut sent_bytes = adapter.next().await.unwrap().unwrap(); + let sent = Message::read_from(&mut sent_bytes).unwrap(); + assert_eq!("signed", sent.headers()[0].name().as_str()); + assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value()); + let inner = Message::read_from(&mut (&sent.payload()[..])).unwrap(); + assert_eq!(&b"test"[..], &inner.payload()[..]); + } + + #[tokio::test] + async fn message_stream_adapter_construction_failure() { + let stream = stream! { + yield Err(EventStreamError::InvalidMessageLength); + }; + let mut adapter = + check_compatible_with_hyper_wrap_stream( + MessageStreamAdapter::::new(Marshaller, TestSigner, stream), + ); + + let result = adapter.next().await.unwrap(); + assert!(result.is_err()); + assert!(matches!( + result.err().unwrap(), + SdkError::ConstructionFailure(_) + )); + } +} diff --git a/rust-runtime/smithy-http/src/lib.rs b/rust-runtime/smithy-http/src/lib.rs index 92837e3a16545b86dd8be6732218f9ce2b946b2c..cbec956ea9d505ac60b6d788b7d7ea745bfde60c 100644 --- a/rust-runtime/smithy-http/src/lib.rs +++ b/rust-runtime/smithy-http/src/lib.rs @@ -12,10 +12,14 @@ pub mod header; pub mod label; pub mod middleware; pub mod operation; -mod pin_util; pub mod property_bag; pub mod query; pub mod response; pub mod result; pub mod retry; + +#[cfg(feature = "event-stream")] +pub mod event_stream; + +mod pin_util; mod urlencode;