Unverified Commit 2e9ade0e authored by John DiSanti's avatar John DiSanti
Browse files

Reintroduce overhauled stalled stream protection and upload support

This commit unreverts 27834ae2 (#3485).
Original commit message:

Overhaul stalled stream protection and add upload support (#3485)

This PR overhauls the existing stalled stream protection with a new
algorithm, and also adds support for minimum throughput on upload
streams. The new algorithm adds support for differentiating between the
user or the server causing the stall, and not timing out if it's the
user causing the stall. This will fix timeout issues when a customer
makes remote service calls in between streaming pieces of information.
parent 170df733
Loading
Loading
Loading
Loading
+43 −1
Original line number Diff line number Diff line
@@ -10,3 +10,45 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[smithy-rs]]
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:

```rust
let config = my_service::Config::builder()
    .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
    // ...
    .build();
```
"""
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
authors = ["jdisanti"]

[[aws-sdk-rust]]
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:

```rust
let config = aws_config::defaults(BehaviorVersion::latest())
    .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
    .load()
    .await;
```
"""
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit."
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
authors = ["jdisanti"]

[[aws-sdk-rust]]
message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit."
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "jdisanti"
+3 −0
Original line number Diff line number Diff line
@@ -48,3 +48,6 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] }
# If you're writing a test with this, take heed! `no-env-filter` means you'll be capturing
# logs from everything that speaks, so be specific with your asserts.
tracing-test = { version = "0.2.4", features = ["no-env-filter"] }

[dependencies]
pin-project-lite = "0.2.13"
+89 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Body wrappers must pass through size_hint

use aws_config::SdkConfig;
use aws_sdk_s3::{
    config::{Credentials, Region, SharedCredentialsProvider},
    primitives::{ByteStream, SdkBody},
    Client,
};
use aws_smithy_runtime::client::http::test_util::{capture_request, infallible_client_fn};
use http_body::Body;

#[tokio::test]
async fn download_body_size_hint_check() {
    let test_body_content = b"hello";
    let test_body = || SdkBody::from(&test_body_content[..]);
    assert_eq!(
        Some(test_body_content.len() as u64),
        (test_body)().size_hint().exact(),
        "pre-condition check"
    );

    let http_client = infallible_client_fn(move |_| {
        http::Response::builder()
            .status(200)
            .body((test_body)())
            .unwrap()
    });
    let sdk_config = SdkConfig::builder()
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .region(Region::new("us-east-1"))
        .http_client(http_client)
        .build();
    let client = Client::new(&sdk_config);
    let response = client
        .get_object()
        .bucket("foo")
        .key("foo")
        .send()
        .await
        .unwrap();
    assert_eq!(
        (
            test_body_content.len() as u64,
            Some(test_body_content.len() as u64),
        ),
        response.body.size_hint(),
        "the size hint should be passed through all the default body wrappers"
    );
}

#[tokio::test]
async fn upload_body_size_hint_check() {
    let test_body_content = b"hello";

    let (http_client, rx) = capture_request(None);
    let sdk_config = SdkConfig::builder()
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .region(Region::new("us-east-1"))
        .http_client(http_client)
        .build();
    let client = Client::new(&sdk_config);
    let body = ByteStream::from_static(test_body_content);
    assert_eq!(
        (
            test_body_content.len() as u64,
            Some(test_body_content.len() as u64),
        ),
        body.size_hint(),
        "pre-condition check"
    );
    let _response = client
        .put_object()
        .bucket("foo")
        .key("foo")
        .body(body)
        .send()
        .await;
    let captured_request = rx.expect_request();
    assert_eq!(
        Some(test_body_content.len() as u64),
        captured_request.body().size_hint().exact(),
        "the size hint should be passed through all the default body wrappers"
    );
}
+85 −36
Original line number Diff line number Diff line
@@ -4,27 +4,90 @@
 */

use aws_credential_types::Credentials;
use aws_sdk_s3::config::{Region, StalledStreamProtectionConfig};
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::{
    config::{Region, StalledStreamProtectionConfig},
    error::BoxError,
};
use aws_sdk_s3::{error::DisplayErrorContext, primitives::ByteStream};
use aws_sdk_s3::{Client, Config};
use bytes::BytesMut;
use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs};
use aws_smithy_types::body::SdkBody;
use bytes::{Bytes, BytesMut};
use http_body::Body;
use std::error::Error;
use std::future::Future;
use std::net::SocketAddr;
use std::time::Duration;
use std::{future::Future, task::Poll};
use std::{net::SocketAddr, pin::Pin, task::Context};
use tokio::{
    net::{TcpListener, TcpStream},
    time::sleep,
};
use tracing::debug;

// This test doesn't work because we can't count on `hyper` to poll the body,
// regardless of whether we schedule a wake. To make this functionality work,
// we'd have to integrate more closely with the orchestrator.
//
// I'll leave this test here because we do eventually want to support stalled
// stream protection for uploads.
#[ignore]
enum SlowBodyState {
    Wait(Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync + 'static>>),
    Send,
    Taken,
}

struct SlowBody {
    state: SlowBodyState,
}

impl SlowBody {
    fn new() -> Self {
        Self {
            state: SlowBodyState::Send,
        }
    }
}

impl Body for SlowBody {
    type Data = Bytes;
    type Error = BoxError;

    fn poll_data(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        loop {
            let mut state = SlowBodyState::Taken;
            std::mem::swap(&mut state, &mut self.state);
            match state {
                SlowBodyState::Wait(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(_) => self.state = SlowBodyState::Send,
                    Poll::Pending => {
                        self.state = SlowBodyState::Wait(fut);
                        return Poll::Pending;
                    }
                },
                SlowBodyState::Send => {
                    self.state = SlowBodyState::Wait(Box::pin(sleep(Duration::from_micros(100))));
                    return Poll::Ready(Some(Ok(Bytes::from_static(
                        b"data_data_data_data_data_data_data_data_data_data_data_data_\
                          data_data_data_data_data_data_data_data_data_data_data_data_\
                          data_data_data_data_data_data_data_data_data_data_data_data_\
                          data_data_data_data_data_data_data_data_data_data_data_data_",
                    ))));
                }
                SlowBodyState::Taken => unreachable!(),
            }
        }
    }

    fn poll_trailers(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
        Poll::Ready(Ok(None))
    }
}

#[tokio::test]
async fn test_stalled_stream_protection_defaults_for_upload() {
    // We spawn a faulty server that will close the connection after
    // writing half of the response body.
    let _logs = capture_test_logs();

    // We spawn a faulty server that will stop all request processing after reading half of the request body.
    let (server, server_addr) = start_faulty_upload_server().await;
    let _ = tokio::spawn(server);

@@ -32,7 +95,8 @@ async fn test_stalled_stream_protection_defaults_for_upload() {
        .credentials_provider(Credentials::for_tests())
        .region(Region::new("us-east-1"))
        .endpoint_url(format!("http://{server_addr}"))
        // .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
        // TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): make stalled stream protection enabled by default with BMV and remove this line
        .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build())
        .build();
    let client = Client::from_conf(conf);

@@ -40,22 +104,19 @@ async fn test_stalled_stream_protection_defaults_for_upload() {
        .put_object()
        .bucket("a-test-bucket")
        .key("stalled-stream-test.txt")
        .body(ByteStream::from_static(b"Hello"))
        .body(ByteStream::new(SdkBody::from_body_0_4(SlowBody::new())))
        .send()
        .await
        .expect_err("upload stream stalled out");

    let err = err.source().expect("inner error exists");
    assert_eq!(
        err.to_string(),
    let err_msg = DisplayErrorContext(&err).to_string();
    assert_str_contains!(
        err_msg,
        "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed"
    );
}

async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr) {
    use tokio::net::{TcpListener, TcpStream};
    use tokio::time::sleep;

    let listener = TcpListener::bind("0.0.0.0:0")
        .await
        .expect("socket is free");
@@ -65,12 +126,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
        let mut buf = BytesMut::new();
        let mut time_to_stall = false;

        loop {
            if time_to_stall {
                debug!("faulty server has read partial request, now getting stuck");
                break;
            }

        while !time_to_stall {
            match socket.try_read_buf(&mut buf) {
                Ok(0) => {
                    unreachable!(
@@ -79,12 +135,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
                }
                Ok(n) => {
                    debug!("read {n} bytes from the socket");

                    // Check to see if we've received some headers
                    if buf.len() >= 128 {
                        let s = String::from_utf8_lossy(&buf);
                        debug!("{s}");

                        time_to_stall = true;
                    }
                }
@@ -98,6 +149,7 @@ async fn start_faulty_upload_server() -> (impl Future<Output = ()>, SocketAddr)
            }
        }

        debug!("faulty server has read partial request, now getting stuck");
        loop {
            tokio::task::yield_now().await
        }
@@ -240,9 +292,6 @@ async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() {
}

async fn start_faulty_download_server() -> (impl Future<Output = ()>, SocketAddr) {
    use tokio::net::{TcpListener, TcpStream};
    use tokio::time::sleep;

    let listener = TcpListener::bind("0.0.0.0:0")
        .await
        .expect("socket is free");
+1 −4
Original line number Diff line number Diff line
@@ -120,15 +120,12 @@ class StalledStreamProtectionOperationCustomization(
                is OperationSection.AdditionalInterceptors -> {
                    val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection")
                    section.registerInterceptor(rc, this) {
                        // Currently, only response bodies are protected/supported because
                        // we can't count on hyper to poll a request body on wake.
                        rustTemplate(
                            """
                            #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody)
                            #{StalledStreamProtectionInterceptor}::default()
                            """,
                            *preludeScope,
                            "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"),
                            "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"),
                        )
                    }
                }
Loading