Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +8 −8 Original line number Diff line number Diff line Loading @@ -125,7 +125,7 @@ class EventStreamUnmarshallerGenerator( } rustBlock("value => ") { rustTemplate( "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", "return Err(#{Error}::unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", *codegenScope, ) } Loading Loading @@ -156,7 +156,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) false -> rustTemplate( "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", *codegenScope, ) } Loading Loading @@ -250,7 +250,7 @@ class EventStreamUnmarshallerGenerator( """ let content_type = response_headers.content_type().unwrap_or_default(); if content_type != ${contentType.dq()} { return Err(#{Error}::Unmarshalling(format!( return Err(#{Error}::unmarshalling(format!( "expected :content-type to be '$contentType', but was '{}'", content_type ))) Loading @@ -269,7 +269,7 @@ class EventStreamUnmarshallerGenerator( rustTemplate( """ std::str::from_utf8(message.payload()) .map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))? .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? """, *codegenScope, ) Loading @@ -288,7 +288,7 @@ class EventStreamUnmarshallerGenerator( """ #{parser}(&message.payload()[..]) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall $memberName: {}", err)) })? """, "parser" to parser, Loading Loading @@ -336,7 +336,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; return Ok(#{UnmarshalledMessage}::Error( #{OpError}::new( Loading @@ -360,7 +360,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; """, "parser" to parser, Loading Loading @@ -394,7 +394,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.SERVER -> { rustTemplate( """ return Err(aws_smithy_eventstream::error::Error::Unmarshalling( return Err(aws_smithy_eventstream::error::Error::unmarshalling( format!("unrecognized exception: {}", response_headers.smithy_type.as_str()), )); """, Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +1 −1 Original line number Diff line number Diff line Loading @@ -124,7 +124,7 @@ class EventStreamErrorMarshallerGenerator( rustTemplate( """ $errorName::Unhandled(_inner) => return Err( #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ), """, *codegenScope, Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +2 −2 Original line number Diff line number Diff line Loading @@ -112,7 +112,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ Self::Input::${UnionGenerator.UnknownVariantName} => return Err( #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ) """, *codegenScope, Loading Loading @@ -212,7 +212,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ #{serializerFn}(&$input) .map_err(|err| #{Error}::Marshalling(format!("{}", err)))? .map_err(|err| #{Error}::marshalling(format!("{}", err)))? """, "serializerFn" to serializerFn, *codegenScope, Loading rust-runtime/aws-smithy-eventstream/src/error.rs +36 −4 Original line number Diff line number Diff line Loading @@ -7,9 +7,8 @@ use aws_smithy_types::DateTime; use std::error::Error as StdError; use std::fmt; #[non_exhaustive] #[derive(Debug)] pub enum Error { pub(crate) enum ErrorKind { HeadersTooLong, HeaderValueTooLong, InvalidHeaderNameLength, Loading @@ -27,12 +26,45 @@ pub enum Error { Unmarshalling(String), } #[derive(Debug)] pub struct Error { kind: ErrorKind, } impl Error { // Used in tests to match on the underlying error kind #[cfg(test)] pub(crate) fn kind(&self) -> &ErrorKind { &self.kind } /// Create an `Error` for failure to marshall a message from a Smithy shape pub fn marshalling(message: impl Into<String>) -> Self { Self { kind: ErrorKind::Marshalling(message.into()), } } /// Create an `Error` for failure to unmarshall a message into a Smithy shape pub fn unmarshalling(message: impl Into<String>) -> Self { Self { kind: ErrorKind::Unmarshalling(message.into()), } } } impl From<ErrorKind> for Error { fn from(kind: ErrorKind) -> Self { Error { kind } } } impl StdError for Error {} impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Error::*; match self { use ErrorKind::*; match &self.kind { HeadersTooLong => write!(f, "headers too long to fit in event stream frame"), HeaderValueTooLong => write!(f, "header value too long to fit in event stream frame"), InvalidHeaderNameLength => write!(f, "invalid header name length"), Loading rust-runtime/aws-smithy-eventstream/src/frame.rs +47 −41 Original line number Diff line number Diff line Loading @@ -7,7 +7,7 @@ use crate::buf::count::CountBuf; use crate::buf::crc::{CrcBuf, CrcBufMut}; use crate::error::Error; use crate::error::{Error, ErrorKind}; use crate::str_bytes::StrBytes; use bytes::{Buf, BufMut, Bytes}; use std::convert::{TryFrom, TryInto}; Loading Loading @@ -75,7 +75,7 @@ pub trait UnmarshallMessage: fmt::Debug { } mod value { use crate::error::Error; use crate::error::{Error, ErrorKind}; use crate::frame::checked; use crate::str_bytes::StrBytes; use aws_smithy_types::DateTime; Loading Loading @@ -179,7 +179,7 @@ mod value { if $buf.remaining() >= size_of::<$size_typ>() { Ok(HeaderValue::$typ($buf.$read_fn())) } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } }; } Loading @@ -198,18 +198,18 @@ mod value { if buffer.remaining() > size_of::<u16>() { let len = buffer.get_u16() as usize; if buffer.remaining() < len { return Err(Error::InvalidHeaderValue); return Err(ErrorKind::InvalidHeaderValue.into()); } let bytes = buffer.copy_to_bytes(len); if value_type == TYPE_STRING { Ok(HeaderValue::String( bytes.try_into().map_err(|_| Error::InvalidUtf8String)?, bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?, )) } else { Ok(HeaderValue::ByteArray(bytes)) } } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_TIMESTAMP => { Loading @@ -217,11 +217,11 @@ mod value { let epoch_millis = buffer.get_i64(); Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis))) } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128), _ => Err(Error::InvalidHeaderValueType(value_type)), _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()), } } Loading @@ -247,19 +247,22 @@ mod value { } ByteArray(val) => { buffer.put_u8(TYPE_BYTE_ARRAY); buffer.put_u16(checked(val.len(), Error::HeaderValueTooLong)?); buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?); buffer.put_slice(&val[..]); } String(val) => { buffer.put_u8(TYPE_STRING); buffer.put_u16(checked(val.as_bytes().len(), Error::HeaderValueTooLong)?); buffer.put_u16(checked( val.as_bytes().len(), ErrorKind::HeaderValueTooLong.into(), )?); buffer.put_slice(&val.as_bytes()[..]); } Timestamp(time) => { buffer.put_u8(TYPE_TIMESTAMP); buffer.put_i64( time.to_millis() .map_err(|_| Error::TimestampValueTooLarge(*time))?, .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?, ); } Uuid(val) => { Loading Loading @@ -329,19 +332,19 @@ impl Header { /// Reads a header from the given `buffer`. fn read_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> { if buffer.remaining() < MIN_HEADER_LEN { return Err(Error::InvalidHeadersLength); return Err(ErrorKind::InvalidHeadersLength.into()); } let mut counting_buf = CountBuf::new(&mut buffer); let name_len = counting_buf.get_u8(); if name_len as usize >= counting_buf.remaining() { return Err(Error::InvalidHeaderNameLength); return Err(ErrorKind::InvalidHeaderNameLength.into()); } let name: StrBytes = counting_buf .copy_to_bytes(name_len as usize) .try_into() .map_err(|_| Error::InvalidUtf8String)?; .map_err(|_| ErrorKind::InvalidUtf8String)?; let value = HeaderValue::read_from(&mut counting_buf)?; Ok((Header::new(name, value), counting_buf.into_count())) } Loading @@ -349,7 +352,7 @@ impl Header { /// Writes the header to the given `buffer`. fn write_to<B: BufMut>(&self, mut buffer: B) -> Result<(), Error> { if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN { return Err(Error::InvalidHeaderNameLength); return Err(ErrorKind::InvalidHeaderNameLength.into()); } buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above")); Loading Loading @@ -414,18 +417,18 @@ impl Message { // If the buffer doesn't have the entire, then error let total_len = crc_buffer.get_u32(); if crc_buffer.remaining() + size_of::<u32>() < total_len as usize { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Validate the prelude let header_len = crc_buffer.get_u32(); let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32()); if expected_crc != prelude_crc { return Err(Error::PreludeChecksumMismatch(expected_crc, prelude_crc)); return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into()); } // The header length can be 0 or >= 2, but must fit within the frame size if header_len == 1 || header_len > max_header_len(total_len)? { return Err(Error::InvalidHeadersLength); return Err(ErrorKind::InvalidHeadersLength.into()); } Ok((total_len, header_len)) } Loading @@ -434,7 +437,7 @@ impl Message { /// the [`MessageFrameDecoder`] instead of this. pub fn read_from<B: Buf>(mut buffer: B) -> Result<Message, Error> { if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Calculate a CRC as we go and read the prelude Loading @@ -444,9 +447,9 @@ impl Message { // Verify we have the full frame before continuing let remaining_len = total_len .checked_sub(PRELUDE_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength)?; .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if crc_buffer.remaining() < remaining_len as usize { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Read headers Loading @@ -456,7 +459,7 @@ impl Message { let (header, bytes_read) = Header::read_from(&mut crc_buffer)?; header_bytes_read += bytes_read; if header_bytes_read > header_len as usize { return Err(Error::InvalidHeaderValue); return Err(ErrorKind::InvalidHeaderValue.into()); } headers.push(header); } Loading @@ -468,7 +471,7 @@ impl Message { let expected_crc = crc_buffer.into_crc(); let message_crc = buffer.get_u32(); if expected_crc != message_crc { return Err(Error::MessageChecksumMismatch(expected_crc, message_crc)); return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into()); } Ok(Message { headers, payload }) Loading @@ -481,8 +484,8 @@ impl Message { header.write_to(&mut headers)?; } let headers_len = checked(headers.len(), Error::HeadersTooLong)?; let payload_len = checked(self.payload.len(), Error::PayloadTooLong)?; let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?; let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?; let message_len = [ PRELUDE_LENGTH_BYTES, headers_len, Loading @@ -491,7 +494,8 @@ impl Message { ] .iter() .try_fold(0u32, |acc, v| { acc.checked_add(*v).ok_or(Error::MessageTooLong) acc.checked_add(*v) .ok_or_else(|| Error::from(ErrorKind::MessageTooLong)) })?; let mut crc_buffer = CrcBufMut::new(buffer); Loading Loading @@ -523,7 +527,7 @@ fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> { fn max_header_len(total_len: u32) -> Result<u32, Error> { total_len .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength) .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> { Loading @@ -531,14 +535,14 @@ fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> { .checked_sub( header_len .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) .ok_or(Error::InvalidHeadersLength)?, .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?, ) .ok_or(Error::InvalidMessageLength) .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } #[cfg(test)] mod message_tests { use crate::error::Error; use crate::error::ErrorKind; use crate::frame::{Header, HeaderValue, Message}; use aws_smithy_types::DateTime; use bytes::Bytes; Loading @@ -546,10 +550,12 @@ mod message_tests { macro_rules! read_message_expect_err { ($bytes:expr, $err:pat) => { let result = Message::read_from(&mut Bytes::from_static($bytes)); let result = result.as_ref(); assert!(result.is_err(), "Expected error, got {:?}", result); assert!( matches!(&result.as_ref(), &Err($err)), matches!(result.err().unwrap().kind(), $err), "Expected {}, got {:?}", stringify!(Err($err)), stringify!($err), result ); }; Loading @@ -559,35 +565,35 @@ mod message_tests { fn invalid_messages() { read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_value_length"), Error::InvalidHeaderValue ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_length_cut_off"), Error::InvalidHeaderValue ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_value_type"), Error::InvalidHeaderValueType(0x60) ErrorKind::InvalidHeaderValueType(0x60) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length"), Error::InvalidHeaderNameLength ErrorKind::InvalidHeaderNameLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_headers_length"), Error::InvalidHeadersLength ErrorKind::InvalidHeadersLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_prelude_checksum"), Error::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_message_checksum"), Error::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length_too_long"), Error::InvalidUtf8String ErrorKind::InvalidUtf8String ); } Loading Loading @@ -735,7 +741,7 @@ impl MessageFrameDecoder { let remaining_len = (&self.prelude[..]) .get_u32() .checked_sub(PRELUDE_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength)?; .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if buffer.remaining() >= remaining_len as usize { return Ok(Some(remaining_len as usize)); } Loading Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +8 −8 Original line number Diff line number Diff line Loading @@ -125,7 +125,7 @@ class EventStreamUnmarshallerGenerator( } rustBlock("value => ") { rustTemplate( "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", "return Err(#{Error}::unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", *codegenScope, ) } Loading Loading @@ -156,7 +156,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) false -> rustTemplate( "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", *codegenScope, ) } Loading Loading @@ -250,7 +250,7 @@ class EventStreamUnmarshallerGenerator( """ let content_type = response_headers.content_type().unwrap_or_default(); if content_type != ${contentType.dq()} { return Err(#{Error}::Unmarshalling(format!( return Err(#{Error}::unmarshalling(format!( "expected :content-type to be '$contentType', but was '{}'", content_type ))) Loading @@ -269,7 +269,7 @@ class EventStreamUnmarshallerGenerator( rustTemplate( """ std::str::from_utf8(message.payload()) .map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))? .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? """, *codegenScope, ) Loading @@ -288,7 +288,7 @@ class EventStreamUnmarshallerGenerator( """ #{parser}(&message.payload()[..]) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall $memberName: {}", err)) })? """, "parser" to parser, Loading Loading @@ -336,7 +336,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; return Ok(#{UnmarshalledMessage}::Error( #{OpError}::new( Loading @@ -360,7 +360,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; """, "parser" to parser, Loading Loading @@ -394,7 +394,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.SERVER -> { rustTemplate( """ return Err(aws_smithy_eventstream::error::Error::Unmarshalling( return Err(aws_smithy_eventstream::error::Error::unmarshalling( format!("unrecognized exception: {}", response_headers.smithy_type.as_str()), )); """, Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +1 −1 Original line number Diff line number Diff line Loading @@ -124,7 +124,7 @@ class EventStreamErrorMarshallerGenerator( rustTemplate( """ $errorName::Unhandled(_inner) => return Err( #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ), """, *codegenScope, Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +2 −2 Original line number Diff line number Diff line Loading @@ -112,7 +112,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ Self::Input::${UnionGenerator.UnknownVariantName} => return Err( #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ) """, *codegenScope, Loading Loading @@ -212,7 +212,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ #{serializerFn}(&$input) .map_err(|err| #{Error}::Marshalling(format!("{}", err)))? .map_err(|err| #{Error}::marshalling(format!("{}", err)))? """, "serializerFn" to serializerFn, *codegenScope, Loading
rust-runtime/aws-smithy-eventstream/src/error.rs +36 −4 Original line number Diff line number Diff line Loading @@ -7,9 +7,8 @@ use aws_smithy_types::DateTime; use std::error::Error as StdError; use std::fmt; #[non_exhaustive] #[derive(Debug)] pub enum Error { pub(crate) enum ErrorKind { HeadersTooLong, HeaderValueTooLong, InvalidHeaderNameLength, Loading @@ -27,12 +26,45 @@ pub enum Error { Unmarshalling(String), } #[derive(Debug)] pub struct Error { kind: ErrorKind, } impl Error { // Used in tests to match on the underlying error kind #[cfg(test)] pub(crate) fn kind(&self) -> &ErrorKind { &self.kind } /// Create an `Error` for failure to marshall a message from a Smithy shape pub fn marshalling(message: impl Into<String>) -> Self { Self { kind: ErrorKind::Marshalling(message.into()), } } /// Create an `Error` for failure to unmarshall a message into a Smithy shape pub fn unmarshalling(message: impl Into<String>) -> Self { Self { kind: ErrorKind::Unmarshalling(message.into()), } } } impl From<ErrorKind> for Error { fn from(kind: ErrorKind) -> Self { Error { kind } } } impl StdError for Error {} impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Error::*; match self { use ErrorKind::*; match &self.kind { HeadersTooLong => write!(f, "headers too long to fit in event stream frame"), HeaderValueTooLong => write!(f, "header value too long to fit in event stream frame"), InvalidHeaderNameLength => write!(f, "invalid header name length"), Loading
rust-runtime/aws-smithy-eventstream/src/frame.rs +47 −41 Original line number Diff line number Diff line Loading @@ -7,7 +7,7 @@ use crate::buf::count::CountBuf; use crate::buf::crc::{CrcBuf, CrcBufMut}; use crate::error::Error; use crate::error::{Error, ErrorKind}; use crate::str_bytes::StrBytes; use bytes::{Buf, BufMut, Bytes}; use std::convert::{TryFrom, TryInto}; Loading Loading @@ -75,7 +75,7 @@ pub trait UnmarshallMessage: fmt::Debug { } mod value { use crate::error::Error; use crate::error::{Error, ErrorKind}; use crate::frame::checked; use crate::str_bytes::StrBytes; use aws_smithy_types::DateTime; Loading Loading @@ -179,7 +179,7 @@ mod value { if $buf.remaining() >= size_of::<$size_typ>() { Ok(HeaderValue::$typ($buf.$read_fn())) } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } }; } Loading @@ -198,18 +198,18 @@ mod value { if buffer.remaining() > size_of::<u16>() { let len = buffer.get_u16() as usize; if buffer.remaining() < len { return Err(Error::InvalidHeaderValue); return Err(ErrorKind::InvalidHeaderValue.into()); } let bytes = buffer.copy_to_bytes(len); if value_type == TYPE_STRING { Ok(HeaderValue::String( bytes.try_into().map_err(|_| Error::InvalidUtf8String)?, bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?, )) } else { Ok(HeaderValue::ByteArray(bytes)) } } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_TIMESTAMP => { Loading @@ -217,11 +217,11 @@ mod value { let epoch_millis = buffer.get_i64(); Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis))) } else { Err(Error::InvalidHeaderValue) Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128), _ => Err(Error::InvalidHeaderValueType(value_type)), _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()), } } Loading @@ -247,19 +247,22 @@ mod value { } ByteArray(val) => { buffer.put_u8(TYPE_BYTE_ARRAY); buffer.put_u16(checked(val.len(), Error::HeaderValueTooLong)?); buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?); buffer.put_slice(&val[..]); } String(val) => { buffer.put_u8(TYPE_STRING); buffer.put_u16(checked(val.as_bytes().len(), Error::HeaderValueTooLong)?); buffer.put_u16(checked( val.as_bytes().len(), ErrorKind::HeaderValueTooLong.into(), )?); buffer.put_slice(&val.as_bytes()[..]); } Timestamp(time) => { buffer.put_u8(TYPE_TIMESTAMP); buffer.put_i64( time.to_millis() .map_err(|_| Error::TimestampValueTooLarge(*time))?, .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?, ); } Uuid(val) => { Loading Loading @@ -329,19 +332,19 @@ impl Header { /// Reads a header from the given `buffer`. fn read_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> { if buffer.remaining() < MIN_HEADER_LEN { return Err(Error::InvalidHeadersLength); return Err(ErrorKind::InvalidHeadersLength.into()); } let mut counting_buf = CountBuf::new(&mut buffer); let name_len = counting_buf.get_u8(); if name_len as usize >= counting_buf.remaining() { return Err(Error::InvalidHeaderNameLength); return Err(ErrorKind::InvalidHeaderNameLength.into()); } let name: StrBytes = counting_buf .copy_to_bytes(name_len as usize) .try_into() .map_err(|_| Error::InvalidUtf8String)?; .map_err(|_| ErrorKind::InvalidUtf8String)?; let value = HeaderValue::read_from(&mut counting_buf)?; Ok((Header::new(name, value), counting_buf.into_count())) } Loading @@ -349,7 +352,7 @@ impl Header { /// Writes the header to the given `buffer`. fn write_to<B: BufMut>(&self, mut buffer: B) -> Result<(), Error> { if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN { return Err(Error::InvalidHeaderNameLength); return Err(ErrorKind::InvalidHeaderNameLength.into()); } buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above")); Loading Loading @@ -414,18 +417,18 @@ impl Message { // If the buffer doesn't have the entire, then error let total_len = crc_buffer.get_u32(); if crc_buffer.remaining() + size_of::<u32>() < total_len as usize { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Validate the prelude let header_len = crc_buffer.get_u32(); let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32()); if expected_crc != prelude_crc { return Err(Error::PreludeChecksumMismatch(expected_crc, prelude_crc)); return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into()); } // The header length can be 0 or >= 2, but must fit within the frame size if header_len == 1 || header_len > max_header_len(total_len)? { return Err(Error::InvalidHeadersLength); return Err(ErrorKind::InvalidHeadersLength.into()); } Ok((total_len, header_len)) } Loading @@ -434,7 +437,7 @@ impl Message { /// the [`MessageFrameDecoder`] instead of this. pub fn read_from<B: Buf>(mut buffer: B) -> Result<Message, Error> { if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Calculate a CRC as we go and read the prelude Loading @@ -444,9 +447,9 @@ impl Message { // Verify we have the full frame before continuing let remaining_len = total_len .checked_sub(PRELUDE_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength)?; .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if crc_buffer.remaining() < remaining_len as usize { return Err(Error::InvalidMessageLength); return Err(ErrorKind::InvalidMessageLength.into()); } // Read headers Loading @@ -456,7 +459,7 @@ impl Message { let (header, bytes_read) = Header::read_from(&mut crc_buffer)?; header_bytes_read += bytes_read; if header_bytes_read > header_len as usize { return Err(Error::InvalidHeaderValue); return Err(ErrorKind::InvalidHeaderValue.into()); } headers.push(header); } Loading @@ -468,7 +471,7 @@ impl Message { let expected_crc = crc_buffer.into_crc(); let message_crc = buffer.get_u32(); if expected_crc != message_crc { return Err(Error::MessageChecksumMismatch(expected_crc, message_crc)); return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into()); } Ok(Message { headers, payload }) Loading @@ -481,8 +484,8 @@ impl Message { header.write_to(&mut headers)?; } let headers_len = checked(headers.len(), Error::HeadersTooLong)?; let payload_len = checked(self.payload.len(), Error::PayloadTooLong)?; let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?; let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?; let message_len = [ PRELUDE_LENGTH_BYTES, headers_len, Loading @@ -491,7 +494,8 @@ impl Message { ] .iter() .try_fold(0u32, |acc, v| { acc.checked_add(*v).ok_or(Error::MessageTooLong) acc.checked_add(*v) .ok_or_else(|| Error::from(ErrorKind::MessageTooLong)) })?; let mut crc_buffer = CrcBufMut::new(buffer); Loading Loading @@ -523,7 +527,7 @@ fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> { fn max_header_len(total_len: u32) -> Result<u32, Error> { total_len .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength) .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> { Loading @@ -531,14 +535,14 @@ fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> { .checked_sub( header_len .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) .ok_or(Error::InvalidHeadersLength)?, .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?, ) .ok_or(Error::InvalidMessageLength) .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } #[cfg(test)] mod message_tests { use crate::error::Error; use crate::error::ErrorKind; use crate::frame::{Header, HeaderValue, Message}; use aws_smithy_types::DateTime; use bytes::Bytes; Loading @@ -546,10 +550,12 @@ mod message_tests { macro_rules! read_message_expect_err { ($bytes:expr, $err:pat) => { let result = Message::read_from(&mut Bytes::from_static($bytes)); let result = result.as_ref(); assert!(result.is_err(), "Expected error, got {:?}", result); assert!( matches!(&result.as_ref(), &Err($err)), matches!(result.err().unwrap().kind(), $err), "Expected {}, got {:?}", stringify!(Err($err)), stringify!($err), result ); }; Loading @@ -559,35 +565,35 @@ mod message_tests { fn invalid_messages() { read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_value_length"), Error::InvalidHeaderValue ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_length_cut_off"), Error::InvalidHeaderValue ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_value_type"), Error::InvalidHeaderValueType(0x60) ErrorKind::InvalidHeaderValueType(0x60) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length"), Error::InvalidHeaderNameLength ErrorKind::InvalidHeaderNameLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_headers_length"), Error::InvalidHeadersLength ErrorKind::InvalidHeadersLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_prelude_checksum"), Error::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_message_checksum"), Error::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length_too_long"), Error::InvalidUtf8String ErrorKind::InvalidUtf8String ); } Loading Loading @@ -735,7 +741,7 @@ impl MessageFrameDecoder { let remaining_len = (&self.prelude[..]) .get_u32() .checked_sub(PRELUDE_LENGTH_BYTES) .ok_or(Error::InvalidMessageLength)?; .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if buffer.remaining() >= remaining_len as usize { return Ok(Some(remaining_len as usize)); } Loading