Unverified Commit 12f217b3 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Revamp errors in `aws-smithy-eventstream` (#1873)

parent 6aef53a0
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -125,7 +125,7 @@ class EventStreamUnmarshallerGenerator(
                    }
                    rustBlock("value => ") {
                        rustTemplate(
                            "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));",
                            "return Err(#{Error}::unmarshalling(format!(\"unrecognized :message-type: {}\", value)));",
                            *codegenScope,
                        )
                    }
@@ -156,7 +156,7 @@ class EventStreamUnmarshallerGenerator(
                        *codegenScope,
                    )
                    false -> rustTemplate(
                        "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));",
                        "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));",
                        *codegenScope,
                    )
                }
@@ -250,7 +250,7 @@ class EventStreamUnmarshallerGenerator(
                """
                let content_type = response_headers.content_type().unwrap_or_default();
                if content_type != ${contentType.dq()} {
                    return Err(#{Error}::Unmarshalling(format!(
                    return Err(#{Error}::unmarshalling(format!(
                        "expected :content-type to be '$contentType', but was '{}'",
                        content_type
                    )))
@@ -269,7 +269,7 @@ class EventStreamUnmarshallerGenerator(
                    rustTemplate(
                        """
                        std::str::from_utf8(message.payload())
                            .map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))?
                            .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))?
                        """,
                        *codegenScope,
                    )
@@ -288,7 +288,7 @@ class EventStreamUnmarshallerGenerator(
            """
            #{parser}(&message.payload()[..])
                .map_err(|err| {
                    #{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err))
                    #{Error}::unmarshalling(format!("failed to unmarshall $memberName: {}", err))
                })?
            """,
            "parser" to parser,
@@ -336,7 +336,7 @@ class EventStreamUnmarshallerGenerator(
                                    """
                                    builder = #{parser}(&message.payload()[..], builder)
                                        .map_err(|err| {
                                            #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
                                            #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
                                        })?;
                                    return Ok(#{UnmarshalledMessage}::Error(
                                        #{OpError}::new(
@@ -360,7 +360,7 @@ class EventStreamUnmarshallerGenerator(
                                    """
                                    builder = #{parser}(&message.payload()[..], builder)
                                        .map_err(|err| {
                                            #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
                                            #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err))
                                        })?;
                                    """,
                                    "parser" to parser,
@@ -394,7 +394,7 @@ class EventStreamUnmarshallerGenerator(
            CodegenTarget.SERVER -> {
                rustTemplate(
                    """
                    return Err(aws_smithy_eventstream::error::Error::Unmarshalling(
                    return Err(aws_smithy_eventstream::error::Error::unmarshalling(
                    format!("unrecognized exception: {}", response_headers.smithy_type.as_str()),
                    ));
                    """,
+1 −1
Original line number Diff line number Diff line
@@ -124,7 +124,7 @@ class EventStreamErrorMarshallerGenerator(
                            rustTemplate(
                                """
                                $errorName::Unhandled(_inner) => return Err(
                                    #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned())
                                    #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned())
                                ),
                                """,
                                *codegenScope,
+2 −2
Original line number Diff line number Diff line
@@ -112,7 +112,7 @@ open class EventStreamMarshallerGenerator(
                        rustTemplate(
                            """
                            Self::Input::${UnionGenerator.UnknownVariantName} => return Err(
                                #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned())
                                #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned())
                            )
                            """,
                            *codegenScope,
@@ -212,7 +212,7 @@ open class EventStreamMarshallerGenerator(
                    rustTemplate(
                        """
                        #{serializerFn}(&$input)
                            .map_err(|err| #{Error}::Marshalling(format!("{}", err)))?
                            .map_err(|err| #{Error}::marshalling(format!("{}", err)))?
                        """,
                        "serializerFn" to serializerFn,
                        *codegenScope,
+36 −4
Original line number Diff line number Diff line
@@ -7,9 +7,8 @@ use aws_smithy_types::DateTime;
use std::error::Error as StdError;
use std::fmt;

#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
pub(crate) enum ErrorKind {
    HeadersTooLong,
    HeaderValueTooLong,
    InvalidHeaderNameLength,
@@ -27,12 +26,45 @@ pub enum Error {
    Unmarshalling(String),
}

#[derive(Debug)]
pub struct Error {
    kind: ErrorKind,
}

impl Error {
    // Used in tests to match on the underlying error kind
    #[cfg(test)]
    pub(crate) fn kind(&self) -> &ErrorKind {
        &self.kind
    }

    /// Create an `Error` for failure to marshall a message from a Smithy shape
    pub fn marshalling(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Marshalling(message.into()),
        }
    }

    /// Create an `Error` for failure to unmarshall a message into a Smithy shape
    pub fn unmarshalling(message: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Unmarshalling(message.into()),
        }
    }
}

impl From<ErrorKind> for Error {
    fn from(kind: ErrorKind) -> Self {
        Error { kind }
    }
}

impl StdError for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use Error::*;
        match self {
        use ErrorKind::*;
        match &self.kind {
            HeadersTooLong => write!(f, "headers too long to fit in event stream frame"),
            HeaderValueTooLong => write!(f, "header value too long to fit in event stream frame"),
            InvalidHeaderNameLength => write!(f, "invalid header name length"),
+47 −41
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@

use crate::buf::count::CountBuf;
use crate::buf::crc::{CrcBuf, CrcBufMut};
use crate::error::Error;
use crate::error::{Error, ErrorKind};
use crate::str_bytes::StrBytes;
use bytes::{Buf, BufMut, Bytes};
use std::convert::{TryFrom, TryInto};
@@ -75,7 +75,7 @@ pub trait UnmarshallMessage: fmt::Debug {
}

mod value {
    use crate::error::Error;
    use crate::error::{Error, ErrorKind};
    use crate::frame::checked;
    use crate::str_bytes::StrBytes;
    use aws_smithy_types::DateTime;
@@ -179,7 +179,7 @@ mod value {
            if $buf.remaining() >= size_of::<$size_typ>() {
                Ok(HeaderValue::$typ($buf.$read_fn()))
            } else {
                Err(Error::InvalidHeaderValue)
                Err(ErrorKind::InvalidHeaderValue.into())
            }
        };
    }
@@ -198,18 +198,18 @@ mod value {
                    if buffer.remaining() > size_of::<u16>() {
                        let len = buffer.get_u16() as usize;
                        if buffer.remaining() < len {
                            return Err(Error::InvalidHeaderValue);
                            return Err(ErrorKind::InvalidHeaderValue.into());
                        }
                        let bytes = buffer.copy_to_bytes(len);
                        if value_type == TYPE_STRING {
                            Ok(HeaderValue::String(
                                bytes.try_into().map_err(|_| Error::InvalidUtf8String)?,
                                bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
                            ))
                        } else {
                            Ok(HeaderValue::ByteArray(bytes))
                        }
                    } else {
                        Err(Error::InvalidHeaderValue)
                        Err(ErrorKind::InvalidHeaderValue.into())
                    }
                }
                TYPE_TIMESTAMP => {
@@ -217,11 +217,11 @@ mod value {
                        let epoch_millis = buffer.get_i64();
                        Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
                    } else {
                        Err(Error::InvalidHeaderValue)
                        Err(ErrorKind::InvalidHeaderValue.into())
                    }
                }
                TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
                _ => Err(Error::InvalidHeaderValueType(value_type)),
                _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
            }
        }

@@ -247,19 +247,22 @@ mod value {
                }
                ByteArray(val) => {
                    buffer.put_u8(TYPE_BYTE_ARRAY);
                    buffer.put_u16(checked(val.len(), Error::HeaderValueTooLong)?);
                    buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
                    buffer.put_slice(&val[..]);
                }
                String(val) => {
                    buffer.put_u8(TYPE_STRING);
                    buffer.put_u16(checked(val.as_bytes().len(), Error::HeaderValueTooLong)?);
                    buffer.put_u16(checked(
                        val.as_bytes().len(),
                        ErrorKind::HeaderValueTooLong.into(),
                    )?);
                    buffer.put_slice(&val.as_bytes()[..]);
                }
                Timestamp(time) => {
                    buffer.put_u8(TYPE_TIMESTAMP);
                    buffer.put_i64(
                        time.to_millis()
                            .map_err(|_| Error::TimestampValueTooLarge(*time))?,
                            .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
                    );
                }
                Uuid(val) => {
@@ -329,19 +332,19 @@ impl Header {
    /// Reads a header from the given `buffer`.
    fn read_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
        if buffer.remaining() < MIN_HEADER_LEN {
            return Err(Error::InvalidHeadersLength);
            return Err(ErrorKind::InvalidHeadersLength.into());
        }

        let mut counting_buf = CountBuf::new(&mut buffer);
        let name_len = counting_buf.get_u8();
        if name_len as usize >= counting_buf.remaining() {
            return Err(Error::InvalidHeaderNameLength);
            return Err(ErrorKind::InvalidHeaderNameLength.into());
        }

        let name: StrBytes = counting_buf
            .copy_to_bytes(name_len as usize)
            .try_into()
            .map_err(|_| Error::InvalidUtf8String)?;
            .map_err(|_| ErrorKind::InvalidUtf8String)?;
        let value = HeaderValue::read_from(&mut counting_buf)?;
        Ok((Header::new(name, value), counting_buf.into_count()))
    }
@@ -349,7 +352,7 @@ impl Header {
    /// Writes the header to the given `buffer`.
    fn write_to<B: BufMut>(&self, mut buffer: B) -> Result<(), Error> {
        if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN {
            return Err(Error::InvalidHeaderNameLength);
            return Err(ErrorKind::InvalidHeaderNameLength.into());
        }

        buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above"));
@@ -414,18 +417,18 @@ impl Message {
        // If the buffer doesn't have the entire, then error
        let total_len = crc_buffer.get_u32();
        if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
            return Err(Error::InvalidMessageLength);
            return Err(ErrorKind::InvalidMessageLength.into());
        }

        // Validate the prelude
        let header_len = crc_buffer.get_u32();
        let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
        if expected_crc != prelude_crc {
            return Err(Error::PreludeChecksumMismatch(expected_crc, prelude_crc));
            return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
        }
        // The header length can be 0 or >= 2, but must fit within the frame size
        if header_len == 1 || header_len > max_header_len(total_len)? {
            return Err(Error::InvalidHeadersLength);
            return Err(ErrorKind::InvalidHeadersLength.into());
        }
        Ok((total_len, header_len))
    }
@@ -434,7 +437,7 @@ impl Message {
    /// the [`MessageFrameDecoder`] instead of this.
    pub fn read_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
        if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
            return Err(Error::InvalidMessageLength);
            return Err(ErrorKind::InvalidMessageLength.into());
        }

        // Calculate a CRC as we go and read the prelude
@@ -444,9 +447,9 @@ impl Message {
        // Verify we have the full frame before continuing
        let remaining_len = total_len
            .checked_sub(PRELUDE_LENGTH_BYTES)
            .ok_or(Error::InvalidMessageLength)?;
            .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
        if crc_buffer.remaining() < remaining_len as usize {
            return Err(Error::InvalidMessageLength);
            return Err(ErrorKind::InvalidMessageLength.into());
        }

        // Read headers
@@ -456,7 +459,7 @@ impl Message {
            let (header, bytes_read) = Header::read_from(&mut crc_buffer)?;
            header_bytes_read += bytes_read;
            if header_bytes_read > header_len as usize {
                return Err(Error::InvalidHeaderValue);
                return Err(ErrorKind::InvalidHeaderValue.into());
            }
            headers.push(header);
        }
@@ -468,7 +471,7 @@ impl Message {
        let expected_crc = crc_buffer.into_crc();
        let message_crc = buffer.get_u32();
        if expected_crc != message_crc {
            return Err(Error::MessageChecksumMismatch(expected_crc, message_crc));
            return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
        }

        Ok(Message { headers, payload })
@@ -481,8 +484,8 @@ impl Message {
            header.write_to(&mut headers)?;
        }

        let headers_len = checked(headers.len(), Error::HeadersTooLong)?;
        let payload_len = checked(self.payload.len(), Error::PayloadTooLong)?;
        let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
        let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?;
        let message_len = [
            PRELUDE_LENGTH_BYTES,
            headers_len,
@@ -491,7 +494,8 @@ impl Message {
        ]
        .iter()
        .try_fold(0u32, |acc, v| {
            acc.checked_add(*v).ok_or(Error::MessageTooLong)
            acc.checked_add(*v)
                .ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
        })?;

        let mut crc_buffer = CrcBufMut::new(buffer);
@@ -523,7 +527,7 @@ fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
fn max_header_len(total_len: u32) -> Result<u32, Error> {
    total_len
        .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
        .ok_or(Error::InvalidMessageLength)
        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
}

fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
@@ -531,14 +535,14 @@ fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
        .checked_sub(
            header_len
                .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
                .ok_or(Error::InvalidHeadersLength)?,
                .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
        )
        .ok_or(Error::InvalidMessageLength)
        .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
}

#[cfg(test)]
mod message_tests {
    use crate::error::Error;
    use crate::error::ErrorKind;
    use crate::frame::{Header, HeaderValue, Message};
    use aws_smithy_types::DateTime;
    use bytes::Bytes;
@@ -546,10 +550,12 @@ mod message_tests {
    macro_rules! read_message_expect_err {
        ($bytes:expr, $err:pat) => {
            let result = Message::read_from(&mut Bytes::from_static($bytes));
            let result = result.as_ref();
            assert!(result.is_err(), "Expected error, got {:?}", result);
            assert!(
                matches!(&result.as_ref(), &Err($err)),
                matches!(result.err().unwrap().kind(), $err),
                "Expected {}, got {:?}",
                stringify!(Err($err)),
                stringify!($err),
                result
            );
        };
@@ -559,35 +565,35 @@ mod message_tests {
    fn invalid_messages() {
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_value_length"),
            Error::InvalidHeaderValue
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_string_length_cut_off"),
            Error::InvalidHeaderValue
            ErrorKind::InvalidHeaderValue
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_value_type"),
            Error::InvalidHeaderValueType(0x60)
            ErrorKind::InvalidHeaderValueType(0x60)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length"),
            Error::InvalidHeaderNameLength
            ErrorKind::InvalidHeaderNameLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_headers_length"),
            Error::InvalidHeadersLength
            ErrorKind::InvalidHeadersLength
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_prelude_checksum"),
            Error::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
            ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_message_checksum"),
            Error::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
            ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
        );
        read_message_expect_err!(
            include_bytes!("../test_data/invalid_header_name_length_too_long"),
            Error::InvalidUtf8String
            ErrorKind::InvalidUtf8String
        );
    }

@@ -735,7 +741,7 @@ impl MessageFrameDecoder {
            let remaining_len = (&self.prelude[..])
                .get_u32()
                .checked_sub(PRELUDE_LENGTH_BYTES)
                .ok_or(Error::InvalidMessageLength)?;
                .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
            if buffer.remaining() >= remaining_len as usize {
                return Ok(Some(remaining_len as usize));
            }
Loading