Unverified Commit b32395bf authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

Feature: checksum related body-wrappers (#1502)

* add: checksum related body-wrappers
add: tests for checksum related body-wrappers

* add: aws-smithy-checksums to CrateSet.key

* add: missing base64::encoded_length fn

* update: use the checksum digest trait's output_size method to get checksum size
add: proptest for base64 encoded length
remove: old base64 encoded length test
remove: incorrect leftover comments and comment quotes

* remove: unnecessary pin

* remove: use of clone during checksum finalization
refactor: checksum fns are no longer fallible
refactor: HttpChecksum is now responsible for reporting size correctly
fix: don't consider trailers as part of ChecksumBody size hint

* remove: Taken error state
update: prefer if-let to match
fix: incorrect is_end_stream impl
refactor: use default trait impl for HttpChecksum::header_value

* fix: clippy lint

* add: header value tests

* undo: accidental retry changes

* add: http-checksum-related errors

* refactor: break up and modularize body wrappers
update: docs
update: ChecksumValidatedBody::new now takes a Box<dyn HttpChecksum>

* fix: broken test

* refactor: rename checksum body structs to be differentiated only by their module
parent 70c6a025
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -11,6 +11,18 @@
# meta = { "breaking" = false, "tada" = false, "bug" = false }
# author = "rcoh"

[[smithy-rs]]
message = "Add checksum calculation and validation wrappers for HTTP bodies."
references = ["smithy-rs#1263"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "Velf"

[[smithy-rs]]
message = "`aws_smithy_http::header::append_merge_header_maps`, a function for merging two `HeaderMap`s, is now public."
references = ["smithy-rs#1263"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "Velf"

[[aws-sdk-rust]]
message = """
Fix compilation of `aws-config` with `rustls` and `native-tls` disabled. The
+7 −2
Original line number Diff line number Diff line
@@ -18,14 +18,19 @@ aws-smithy-types = { path = "../aws-smithy-types" }
bytes = "1"
crc32c = "0.6"
crc32fast = "1.3"
http = "0.2.3"
hex = "0.4.3"
http = "0.2.8"
http-body = "0.4.4"
md-5 = "0.10"
pin-project-lite = "0.2.9"
sha1 = "0.10"
sha2 = "0.10"
tracing = "0.1"

[dev-dependencies]
bytes-utils = "0.1.2"
pretty_assertions = "1.2"
tokio = { version = "1.6", features = ["macros"] }
tokio = { version = "1.6", features = ["macros", "rt"] }
tracing-test = "0.2.1"

[package.metadata.docs.rs]
+9 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! HTTP body-wrappers that calculate and validate checksums.

pub mod calculate;
pub mod validate;
+153 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Functionality for calculating the checksum of an HTTP body and emitting it as trailers.

use crate::http::HttpChecksum;

use aws_smithy_http::body::SdkBody;
use aws_smithy_http::header::append_merge_header_maps;

use http::HeaderMap;
use http_body::SizeHint;
use pin_project_lite::pin_project;

use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
    ///
    /// Data read from the inner body is passed through unchanged; the checksum is only
    /// surfaced when the trailers are polled.
    pub struct ChecksumBody<InnerBody> {
            // The wrapped body. It is structurally pinned so that it can be polled through
            // the projection.
            #[pin]
            body: InnerBody,
            // The running checksum. It is `Some` until the trailers have been produced, at
            // which point it is taken and left as `None`.
            checksum: Option<Box<dyn HttpChecksum>>,
    }
}

impl ChecksumBody<SdkBody> {
    /// Wrap `body` so that every chunk read from it is also fed into `checksum`.
    ///
    /// The checksum is held as `Some(..)` until it is consumed while producing the trailers.
    pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
        let checksum = Some(checksum);
        Self { body, checksum }
    }
}

impl http_body::Body for ChecksumBody<SdkBody> {
    type Data = bytes::Bytes;
    type Error = aws_smithy_http::body::Error;

    /// Poll the inner body for data, feeding every successfully read chunk into the checksum
    /// before handing the untouched poll result back to the caller.
    ///
    /// # Panics
    ///
    /// Panics if the checksum has already been taken, i.e. when this is called after
    /// `poll_trailers` has produced the checksum trailer.
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let this = self.project();

        // The checksum presence is verified before the inner body is touched.
        let running_checksum = match this.checksum.as_mut() {
            Some(checksum) => checksum,
            None => unreachable!("This can only fail if poll_data is called again after poll_trailers, which is invalid"),
        };

        let inner_poll = this.body.poll_data(cx);
        if let Poll::Ready(Some(Ok(chunk))) = &inner_poll {
            running_checksum.update(chunk);
        }

        inner_poll
    }

    /// Poll the inner body for trailers and, once they are ready, add the checksum headers.
    ///
    /// The checksum is consumed the first time the inner trailers resolve successfully. If it
    /// was already consumed, `Ok(None)` is returned. Pending and error results from the inner
    /// body are forwarded as-is and leave the checksum in place.
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
        let this = self.project();

        match this.body.poll_trailers(cx) {
            Poll::Ready(Ok(inner_trailers)) => {
                let checksum_headers = match this.checksum.take() {
                    Some(checksum) => checksum.headers(),
                    None => return Poll::Ready(Ok(None)),
                };

                // Combine with whatever trailers the inner body produced, if any.
                let trailers = match inner_trailers {
                    Some(inner_trailers) => {
                        append_merge_header_maps(inner_trailers, checksum_headers)
                    }
                    None => checksum_headers,
                };

                Poll::Ready(Ok(Some(trailers)))
            }
            not_ready_or_failed => not_ready_or_failed,
        }
    }

    /// The stream has ended only when the inner body is finished *and* the checksum has
    /// already been consumed.
    fn is_end_stream(&self) -> bool {
        let inner_finished = self.body.is_end_stream();
        let checksum_consumed = self.checksum.is_none();

        inner_finished && checksum_consumed
    }

    /// Report the inner body's size hint unchanged; nothing is added for the checksum trailer.
    fn size_hint(&self) -> SizeHint {
        let inner = &self.body;
        inner.size_hint()
    }
}

#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::http::new_from_algorithm;
    use crate::http::{CRC_32_HEADER_NAME, CRC_32_NAME};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_types::base64;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::io::Read;

    /// Base64-decode a header value and render the decoded bytes as an upper-case hex string
    /// prefixed with `0x`.
    fn header_value_as_checksum_string(header_value: &http::HeaderValue) -> String {
        let encoded = header_value.to_str().unwrap();
        let raw_checksum = base64::decode(encoded).unwrap();

        let mut rendered = String::from("0x");
        for byte in raw_checksum {
            rendered.push_str(&format!("{:02X?}", byte));
        }

        rendered
    }

    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let crc32 = new_from_algorithm(CRC_32_NAME).unwrap();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), crc32);

        // Drain all of the data before asking for the trailers.
        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        // The wrapper must pass the data through complete and unaltered
        assert_eq!(input_text, output_text);

        let trailers = body
            .trailers()
            .await
            .expect("checksum generation was without error")
            .expect("trailers were set");
        let crc32_trailer = trailers
            .get(&CRC_32_HEADER_NAME)
            .expect("trailers contain crc32 checksum");
        let rendered_checksum = header_value_as_checksum_string(crc32_trailer);

        // Known correct checksum for the input "This is some test text for an SdkBody"
        assert_eq!("0x99B01F72", rendered_checksum);
    }
}
+222 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Functionality for validating an HTTP body against a given precalculated checksum and emitting an
//! error if it doesn't match.

use crate::http::HttpChecksum;

use aws_smithy_http::body::SdkBody;

use bytes::Bytes;
use http::{HeaderMap, HeaderValue};
use http_body::SizeHint;
use pin_project_lite::pin_project;

use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit an error if it
    /// doesn't match the precalculated checksum.
    ///
    /// The comparison happens once the inner body stops returning data.
    pub struct ChecksumBody<InnerBody> {
        // The wrapped body. It is structurally pinned so that it can be polled through the
        // projection.
        #[pin]
        inner: InnerBody,
        // The running checksum. It is `Some` until the inner body has been fully read, at
        // which point it is taken, finalized, and left as `None`.
        checksum: Option<Box<dyn HttpChecksum>>,
        // The checksum the fully-read body is expected to produce.
        precalculated_checksum: Bytes,
    }
}

impl ChecksumBody<SdkBody> {
    /// Given an `SdkBody`, a `Box<dyn HttpChecksum>`, and a precalculated checksum represented
    /// as `Bytes`, create a new `ChecksumBody<SdkBody>`.
    pub fn new(
        body: SdkBody,
        checksum: Box<dyn HttpChecksum>,
        precalculated_checksum: Bytes,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(Self {
            inner: body,
            checksum: Some(checksum),
            precalculated_checksum,
        })
    }

    fn poll_inner(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, aws_smithy_http::body::Error>>> {
        use http_body::Body;

        let this = self.project();
        let checksum = this.checksum;

        match this.inner.poll_data(cx) {
            Poll::Ready(Some(Ok(data))) => {
                tracing::trace!(
                    "reading {} bytes from the body and updating the checksum calculation",
                    data.len()
                );
                let checksum = match checksum.as_mut() {
                    Some(checksum) => checksum,
                    None => {
                        unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled.");
                    }
                };

                checksum.update(&data);
                Poll::Ready(Some(Ok(data)))
            }
            // Once the inner body has stopped returning data, check the checksum
            // and return an error if it doesn't match.
            Poll::Ready(None) => {
                tracing::trace!("finished reading from body, calculating final checksum");
                let checksum = match checksum.take() {
                    Some(checksum) => checksum,
                    None => {
                        // If the checksum was already taken and this was polled again anyways,
                        // then return nothing
                        return Poll::Ready(None);
                    }
                };

                let actual_checksum = checksum.finalize();
                if *this.precalculated_checksum == actual_checksum {
                    Poll::Ready(None)
                } else {
                    // So many parens it's starting to look like LISP
                    Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
                        expected: this.precalculated_checksum.clone(),
                        actual: actual_checksum,
                    }))))
                }
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Errors related to checksum calculation and validation
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Error {
    /// The actual checksum didn't match the expected checksum. The checksummed data has been
    /// altered since the expected checksum was calculated.
    ChecksumMismatch {
        /// The precalculated checksum that the body was expected to produce.
        expected: Bytes,
        /// The checksum that was actually calculated from the body's data.
        actual: Bytes,
    },
}

impl Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            Error::ChecksumMismatch { expected, actual } => write!(
                f,
                "body checksum mismatch. expected body checksum to be {} but it was {}",
                hex::encode(expected),
                hex::encode(actual)
            ),
        }
    }
}

// No trait methods are overridden here, so `Error` relies on the trait's default
// implementations.
impl std::error::Error for Error {}

impl http_body::Body for ChecksumBody<SdkBody> {
    type Data = Bytes;
    type Error = aws_smithy_http::body::Error;

    /// Yield data from the inner body, validating the checksum once the data runs out.
    ///
    /// All of the work is done by the inherent `poll_inner` method.
    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        Self::poll_inner(self, cx)
    }

    /// Forward the inner body's trailers untouched.
    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
        let this = self.project();
        this.inner.poll_trailers(cx)
    }

    /// The stream is considered finished once the checksum has been taken, which happens when
    /// the inner body stops returning data.
    fn is_end_stream(&self) -> bool {
        let checksum = &self.checksum;
        checksum.is_none()
    }

    /// Report the inner body's size hint unchanged.
    fn size_hint(&self) -> SizeHint {
        let inner = &self.inner;
        inner.size_hint()
    }
}

#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::http::new_from_algorithm;
    use aws_smithy_http::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::io::Read;

    /// Compute the CRC32 of `input` and return it as big-endian bytes.
    fn calculate_crc32_checksum(input: &str) -> Bytes {
        let crc = crc32fast::hash(input.as_bytes());
        let big_endian = crc.to_be_bytes();
        Bytes::copy_from_slice(&big_endian)
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let crc32 = new_from_algorithm("crc32").unwrap();
        let mut body = ChecksumBody::new(
            SdkBody::from(input_text),
            crc32,
            non_matching_checksum.clone(),
        )
        .unwrap();

        // Data chunks are ignored; the test passes as soon as the mismatch error shows up.
        while let Some(data) = body.data().await {
            if let Err(e) = data {
                match e.downcast_ref::<Error>().unwrap() {
                    Error::ChecksumMismatch { expected, actual } => {
                        assert_eq!(expected, &non_matching_checksum);
                        assert_eq!(actual, &actual_checksum);
                    }
                }

                return;
            }
        }

        panic!("didn't hit expected error condition");
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let input_text = "This is some test text for an SdkBody";
        let matching_checksum = calculate_crc32_checksum(input_text);
        let crc32 = new_from_algorithm("crc32").unwrap();
        let mut body =
            ChecksumBody::new(SdkBody::from(input_text), crc32, matching_checksum).unwrap();

        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        // The wrapper must pass the data through complete and unaltered
        assert_eq!(input_text, output_text);
    }
}
Loading