Unverified Commit 18fc6923 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Remove client body callbacks (#2065)

parent 3ecefff7
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -625,3 +625,15 @@ message = "Implementation of the Debug trait for container shapes now redacts wh
references = ["smithy-rs#1983", "smithy-rs#2029"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "`SdkBody` callbacks have been removed. If you were using these, please [file an issue](https://github.com/awslabs/aws-sdk-rust/issues/new) so that we can better understand your use-case and provide the support you need."
references = ["smithy-rs#2065"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "`SdkBody` callbacks have been removed. If you were using these, please [file an issue](https://github.com/awslabs/smithy-rs/issues/new) so that we can better understand your use-case and provide the support you need."
references = ["smithy-rs#2065"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"}
author = "jdisanti"
+3 −68
Original line number Diff line number Diff line
@@ -13,9 +13,6 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::callback::BodyCallback;
use crate::header::append_merge_header_maps;

pub type Error = Box<dyn StdError + Send + Sync>;

pin_project! {
@@ -35,9 +32,6 @@ pin_project! {
        // In the event of retry, this function will be called to generate a new body. See
        // [`try_clone()`](SdkBody::try_clone)
        rebuild: Option<Arc<dyn (Fn() -> Inner) + Send + Sync>>,
        // A list of callbacks that will be called at various points of this `SdkBody`'s lifecycle
        #[pin]
        callbacks: Vec<Box<dyn BodyCallback>>,
    }
}

@@ -96,7 +90,6 @@ impl SdkBody {
        Self {
            inner: Inner::Dyn { inner: body },
            rebuild: None,
            callbacks: Vec::new(),
        }
    }

@@ -113,7 +106,6 @@ impl SdkBody {
        SdkBody {
            inner: initial.inner,
            rebuild: Some(Arc::new(move || f().inner)),
            callbacks: Vec::new(),
        }
    }

@@ -121,7 +113,6 @@ impl SdkBody {
        Self {
            inner: Inner::Taken,
            rebuild: None,
            callbacks: Vec::new(),
        }
    }

@@ -129,7 +120,6 @@ impl SdkBody {
        Self {
            inner: Inner::Once { inner: None },
            rebuild: Some(Arc::new(|| Inner::Once { inner: None })),
            callbacks: Vec::new(),
        }
    }

@@ -137,8 +127,8 @@ impl SdkBody {
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, Error>>> {
        let mut this = self.project();
        let polling_result = match this.inner.project() {
        let this = self.project();
        match this.inner.project() {
            InnerProj::Once { ref mut inner } => {
                let data = inner.take();
                match data {
@@ -152,30 +142,8 @@ impl SdkBody {
            InnerProj::Taken => {
                Poll::Ready(Some(Err("A `Taken` body should never be polled".into())))
            }
        };

        match &polling_result {
            // When we get some bytes back from polling, pass those bytes to each callback in turn
            Poll::Ready(Some(Ok(bytes))) => {
                for callback in this.callbacks.iter_mut() {
                    // Callbacks can run into errors when reading bytes. They'll be surfaced here
                    callback.update(bytes)?;
                }
            }
            // When we're done polling for bytes, run each callback's `trailers()` method. If any calls to
            // `trailers()` return an error, propagate that error up. Otherwise, continue.
            Poll::Ready(None) => {
                for callback_result in this.callbacks.iter().map(BodyCallback::trailers) {
                    if let Err(e) = callback_result {
                        return Poll::Ready(Some(Err(e)));
                    }
        }
    }
            _ => (),
        }

        polling_result
    }

    /// If possible, return a reference to this body as `&[u8]`
    ///
@@ -192,12 +160,9 @@ impl SdkBody {
    pub fn try_clone(&self) -> Option<Self> {
        self.rebuild.as_ref().map(|rebuild| {
            let next = rebuild();
            let callbacks = self.callbacks.iter().map(BodyCallback::make_new).collect();

            Self {
                inner: next,
                rebuild: self.rebuild.clone(),
                callbacks,
            }
        })
    }
@@ -206,11 +171,6 @@ impl SdkBody {
        http_body::Body::size_hint(self).exact()
    }

    pub fn with_callback(&mut self, callback: Box<dyn BodyCallback>) -> &mut Self {
        self.callbacks.push(callback);
        self
    }

    pub fn map(self, f: impl Fn(SdkBody) -> SdkBody + Sync + Send + 'static) -> SdkBody {
        if self.rebuild.is_some() {
            SdkBody::retryable(move || f(self.try_clone().unwrap()))
@@ -235,7 +195,6 @@ impl From<Bytes> for SdkBody {
            rebuild: Some(Arc::new(move || Inner::Once {
                inner: Some(bytes.clone()),
            })),
            callbacks: Vec::new(),
        }
    }
}
@@ -245,7 +204,6 @@ impl From<hyper::Body> for SdkBody {
        SdkBody {
            inner: Inner::Streaming { inner: body },
            rebuild: None,
            callbacks: Vec::new(),
        }
    }
}
@@ -283,30 +241,7 @@ impl http_body::Body for SdkBody {
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
        let mut header_map = None;
        // Iterate over all callbacks, checking each for any `HeaderMap`s
        for callback in &self.callbacks {
            match callback.trailers() {
                // If this is the first `HeaderMap` we've encountered, save it
                Ok(Some(right_header_map)) if header_map.is_none() => {
                    header_map = Some(right_header_map);
                }
                // If this is **not** the first `HeaderMap` we've encountered, merge it
                Ok(Some(right_header_map)) if header_map.is_some() => {
                    header_map = Some(append_merge_header_maps(
                        header_map.unwrap(),
                        right_header_map,
                    ));
                }
                // Early return if a callback encountered an error.
                Err(e) => {
                    return Poll::Ready(Err(e));
                }
                // Otherwise, continue on to the next iteration of the loop.
                _ => continue,
            }
        }
        Poll::Ready(Ok(header_map))
        Poll::Ready(Ok(None))
    }

    fn is_end_stream(&self) -> bool {
+0 −16
Original line number Diff line number Diff line
@@ -123,7 +123,6 @@

use crate::body::SdkBody;
use crate::byte_stream::error::Error;
use crate::callback::BodyCallback;
use bytes::Buf;
use bytes::Bytes;
use bytes_utils::SegmentedBuf;
@@ -374,14 +373,6 @@ impl ByteStream {
        FsBuilder::new().file(file).build().await
    }

    /// Set a callback on this `ByteStream`. The callback's methods will be called at various points
    /// throughout this `ByteStream`'s life cycle. See the [`BodyCallback`](BodyCallback) trait for
    /// more information.
    pub fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
        self.inner.with_body_callback(body_callback);
        self
    }

    #[cfg(feature = "rt-tokio")]
    /// Convert this `ByteStream` into a struct that implements [`AsyncRead`](tokio::io::AsyncRead).
    ///
@@ -542,13 +533,6 @@ impl<B> Inner<B> {
    }
}

impl Inner<SdkBody> {
    fn with_body_callback(&mut self, body_callback: Box<dyn BodyCallback>) -> &mut Self {
        self.body.with_callback(body_callback);
        self
    }
}

const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#"
You're running a 32-bit system and this stream's length is too large to be represented with a usize.
Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture.
+0 −172
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! A module for traits that define callbacks that will be called at specific points in an HTTP request's lifecycle.

use http::{HeaderMap, HeaderValue};

type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// A callback that, when inserted into a request body, will be called for corresponding lifecycle events.
pub trait BodyCallback: Send + Sync {
    /// This lifecycle function is called for each chunk **successfully** read. If an error occurs while reading a chunk,
    /// this method will not be called. This method takes `&mut self` so that implementors may modify an implementing
    /// struct/enum's internal state. Implementors may return an error.
    fn update(&mut self, bytes: &[u8]) -> Result<(), BoxError> {
        // "Use" bytes so that the compiler won't complain.
        let _ = bytes;
        Ok(())
    }

    /// This callback is called once all chunks have been read. If the callback encountered one or more errors
    /// while running `update`s, this is how those errors are raised. Implementors may return a [`HeaderMap`][HeaderMap]
    /// that will be appended to the HTTP body as a trailer. This is only useful to do for streaming requests.
    fn trailers(&self) -> Result<Option<HeaderMap<HeaderValue>>, BoxError> {
        Ok(None)
    }

    /// Create a new `BodyCallback` from an existing one. This is called when a `BodyCallback` needs to be
    /// re-initialized with default state. For example: when a request has a body that needs to be
    /// rebuilt, all callbacks for that body need to be run again but with a fresh internal state.
    fn make_new(&self) -> Box<dyn BodyCallback>;
}

impl BodyCallback for Box<dyn BodyCallback> {
    fn update(&mut self, bytes: &[u8]) -> Result<(), BoxError> {
        self.as_mut().update(bytes)
    }
    fn trailers(&self) -> Result<Option<HeaderMap<HeaderValue>>, BoxError> {
        self.as_ref().trailers()
    }
    fn make_new(&self) -> Box<dyn BodyCallback> {
        self.as_ref().make_new()
    }
}

#[cfg(test)]
mod tests {
    use super::{BodyCallback, BoxError};
    use crate::body::SdkBody;
    use crate::byte_stream::ByteStream;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tracing_test::traced_test]
    #[tokio::test]
    async fn callbacks_are_called_for_update() {
        struct CallbackA;
        struct CallbackB;

        impl BodyCallback for CallbackA {
            fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
                tracing::debug!("callback A was called");

                Ok(())
            }

            fn make_new(&self) -> Box<dyn BodyCallback> {
                Box::new(Self)
            }
        }

        impl BodyCallback for CallbackB {
            fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
                tracing::debug!("callback B was called");

                Ok(())
            }

            fn make_new(&self) -> Box<dyn BodyCallback> {
                Box::new(Self)
            }
        }

        let mut body = SdkBody::from("test");
        body.with_callback(Box::new(CallbackA))
            .with_callback(Box::new(CallbackB));

        let body = ByteStream::from(body).collect().await.unwrap().into_bytes();
        let body = std::str::from_utf8(&body).unwrap();

        // Assert that the body that went in is the same as the body coming out.
        assert_eq!(body, "test");

        // Test that all callbacks were called.
        assert!(logs_contain("callback A was called"));
        assert!(logs_contain("callback B was called"));
    }

    struct TestCallback {
        times_called: Arc<AtomicUsize>,
    }

    impl BodyCallback for TestCallback {
        fn update(&mut self, _bytes: &[u8]) -> Result<(), BoxError> {
            self.times_called.fetch_add(1, Ordering::SeqCst);

            Ok(())
        }

        fn make_new(&self) -> Box<dyn BodyCallback> {
            Box::new(Self {
                times_called: Arc::new(AtomicUsize::new(0)),
            })
        }
    }

    #[tokio::test]
    async fn callback_for_buffered_body_is_called_once() {
        let times_called = Arc::new(AtomicUsize::new(0));
        let test_text: String = (0..=1000)
            .into_iter()
            .map(|n| format!("line {}\n", n))
            .collect();

        {
            let mut body = SdkBody::from(test_text);
            let callback = TestCallback {
                times_called: times_called.clone(),
            };
            body.with_callback(Box::new(callback));
            let _body = ByteStream::new(body).collect().await.unwrap().into_bytes();
        }

        let times_called = Arc::try_unwrap(times_called).unwrap();
        let times_called = times_called.into_inner();

        // Callback only gets called once because it's not a streaming body
        assert_eq!(times_called, 1);
    }

    #[tracing_test::traced_test]
    #[tokio::test]
    async fn callback_for_streaming_body_is_called_per_chunk() {
        // Include a large body of text for testing
        let times_called = Arc::new(AtomicUsize::new(0));

        {
            let test_stream = tokio_stream::iter(
                (1..=1000)
                    .into_iter()
                    .map(|n| -> Result<String, std::io::Error> { Ok(format!("line {}\n", n)) }),
            );
            let mut body = SdkBody::from(hyper::body::Body::wrap_stream(test_stream));
            tracing::trace!("{:?}", body);
            assert!(logs_contain("Streaming(Body(Streaming))"));

            let callback = TestCallback {
                times_called: times_called.clone(),
            };
            body.with_callback(Box::new(callback));
            let _body = ByteStream::new(body).collect().await.unwrap().into_bytes();
        }

        let times_called = Arc::try_unwrap(times_called).unwrap();
        let times_called = times_called.into_inner();

        // Callback is called once per chunk
        assert_eq!(times_called, 1000);
    }
}
+0 −1
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@
#![cfg_attr(docsrs, feature(doc_cfg))]

pub mod body;
pub mod callback;
pub mod endpoint;
pub mod header;
pub mod http_versions;