Unverified Commit dda3c098 authored by Liam Perlaki's avatar Liam Perlaki Committed by GitHub
Browse files

keep CompleteMultipartUpload alive (#141)



* keep CompleteMultipartUpload alive

* use s3_resp.headers in CompleteMultipartUpload

* propagate error in keepalive body

* unify serialize_error

* set transfer-encoding and trailer header

* add keep_alive_body unit tests

---------

Co-authored-by: default avatarNugine <nugine@foxmail.com>
parent 7b949d0c
Loading
Loading
Loading
Loading
+40 −16
Original line number Diff line number Diff line
@@ -294,7 +294,11 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {
                }

                if is_xml_output(ty) {
                    if op.name == "CompleteMultipartUpload" {
                        g!("http::set_xml_body_no_decl(&mut res, &x)?;");
                    } else {
                        g!("http::set_xml_body(&mut res, &x)?;");
                    }
                } else if let Some(field) = ty.fields.iter().find(|x| x.position == "payload") {
                    match field.type_.as_str() {
                        "Policy" => {
@@ -652,12 +656,31 @@ fn codegen_op_http_call(op: &Operation) {
        g!("let overrided_headers = super::get_object::extract_overrided_response_headers(&s3_req)?;");
    }

    if op.name == "CompleteMultipartUpload" {
        g!("let s3 = s3.clone();");
        g!("let fut = async move {{");
        g!("let result = s3.{method}(s3_req).await;");
        g!("match result {{");
        glines![
            "Ok(s3_resp) => {
                let mut resp = Self::serialize_http(s3_resp.output)?;
                resp.headers.extend(s3_resp.headers);
                Ok(resp)
            }"
        ];
        g!("Err(err) => super::serialize_error(err, true).map_err(Into::into),");
        g!("}}");
        g!("}};");
        g!("let mut resp = http::Response::with_status(http::StatusCode::OK);");
        g!("http::set_keep_alive_xml_body(&mut resp, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;");
        g!("http::add_opt_header(&mut resp, \"trailer\", Some([X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED.as_str(), X_AMZ_EXPIRATION.as_str(), X_AMZ_REQUEST_CHARGED.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION.as_str(), X_AMZ_VERSION_ID.as_str()].join(\",\")))?;");
    } else {
        g!("let result = s3.{method}(s3_req).await;");

        glines![
            "let s3_resp = match result {"
            "    Ok(val) => val,"
        "    Err(err) => return super::serialize_error(err),"
            "    Err(err) => return super::serialize_error(err, false),"
            "};"
        ];

@@ -671,6 +694,7 @@ fn codegen_op_http_call(op: &Operation) {
        }

        g!("resp.extensions.extend(s3_resp.extensions);");
    }
    g!("Ok(resp)");

    g!("}}");
+3 −0
Original line number Diff line number Diff line
@@ -57,5 +57,8 @@ transform-stream = "0.3.0"
urlencoding = "2.1.3"
zeroize = "1.6.0"

sync_wrapper = { version = "1.0.0", default-features = false }
tokio = { version = "1.31.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.31.0", features = ["full"] }
+30 −0
Original line number Diff line number Diff line
@@ -5,8 +5,10 @@ use crate::dto::SelectObjectContentEventStream;
use crate::dto::{Metadata, StreamingBlob, Timestamp, TimestampFormat};
use crate::error::{S3Error, S3Result};
use crate::http::{HeaderName, HeaderValue};
use crate::keep_alive_body::KeepAliveBody;
use crate::utils::format::fmt_timestamp;
use crate::xml;
use crate::StdError;

use std::convert::Infallible;
use std::fmt::Write as _;
@@ -105,6 +107,34 @@ pub fn set_xml_body<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result
    Ok(())
}

#[allow(clippy::declare_interior_mutable_const)]
const TRANSFER_ENCODING_CHUNKED: HeaderValue = HeaderValue::from_static("chunked");

pub fn set_keep_alive_xml_body(
    res: &mut Response,
    fut: impl std::future::Future<Output = Result<Response, StdError>> + Send + Sync + 'static,
    duration: std::time::Duration,
) -> S3Result {
    let mut buf = Vec::with_capacity(40);
    let mut ser = xml::Serializer::new(&mut buf);
    ser.decl().map_err(S3Error::internal_error)?;

    res.body = Body::http_body(KeepAliveBody::with_initial_body(fut, buf.into(), duration));
    res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
    res.headers
        .insert(hyper::header::TRANSFER_ENCODING, TRANSFER_ENCODING_CHUNKED);
    Ok(())
}

pub fn set_xml_body_no_decl<T: xml::Serialize>(res: &mut Response, val: &T) -> S3Result {
    let mut buf = Vec::with_capacity(256);
    let mut ser = xml::Serializer::new(&mut buf);
    val.serialize(&mut ser).map_err(S3Error::internal_error)?;
    res.body = Body::from(buf);
    res.headers.insert(hyper::header::CONTENT_TYPE, APPLICATION_XML);
    Ok(())
}

pub fn set_stream_body(res: &mut Response, stream: StreamingBlob) {
    res.body = Body::from(stream);
}
+163 −0
Original line number Diff line number Diff line
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use bytes::Bytes;
use http_body::{Body, Frame};
use tokio::time::Interval;

use crate::{http::Response, StdError};

// sends whitespace while the future is pending
pin_project_lite::pin_project! {

    pub struct KeepAliveBody<F> {
        #[pin]
        inner: F,
        initial_body: Option<Bytes>,
        response: Option<Response>,
        interval: Interval,
        done: bool,
    }
}
impl<F> KeepAliveBody<F> {
    pub fn new(inner: F, interval: Duration) -> Self {
        Self {
            inner,
            initial_body: None,
            response: None,
            interval: tokio::time::interval(interval),
            done: false,
        }
    }

    pub fn with_initial_body(inner: F, initial_body: Bytes, interval: Duration) -> Self {
        Self {
            inner,
            initial_body: Some(initial_body),
            response: None,
            interval: tokio::time::interval(interval),
            done: false,
        }
    }
}

impl<F> Body for KeepAliveBody<F>
where
    F: Future<Output = Result<Response, StdError>>,
{
    type Data = Bytes;

    type Error = StdError;

    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        if self.done {
            return Poll::Ready(None);
        }
        let mut this = self.project();
        if let Some(initial_body) = this.initial_body.take() {
            cx.waker().wake_by_ref();
            return Poll::Ready(Some(Ok(Frame::data(initial_body))));
        }
        loop {
            if let Some(response) = &mut this.response {
                let frame = std::task::ready!(Pin::new(&mut response.body).poll_frame(cx)?);
                if let Some(frame) = frame {
                    return Poll::Ready(Some(Ok(frame)));
                }
                *this.done = true;
                return Poll::Ready(Some(Ok(Frame::trailers(std::mem::take(&mut response.headers)))));
            }
            match this.inner.as_mut().poll(cx) {
                Poll::Ready(response) => match response {
                    Ok(response) => {
                        *this.response = Some(response);
                    }
                    Err(e) => {
                        *this.done = true;
                        return Poll::Ready(Some(Err(e)));
                    }
                },
                Poll::Pending => match this.interval.poll_tick(cx) {
                    Poll::Ready(_) => return Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b" "))))),
                    Poll::Pending => return Poll::Pending,
                },
            }
        }
    }

    fn is_end_stream(&self) -> bool {
        self.done
    }
}

#[cfg(test)]
mod tests {
    use http_body_util::BodyExt;
    use hyper::{header::HeaderValue, StatusCode};

    use super::*;

    #[tokio::test]
    async fn keep_alive_body() {
        let body = KeepAliveBody::with_initial_body(
            async {
                let mut res = Response::with_status(StatusCode::OK);
                res.body = Bytes::from_static(b" world").into();
                res.headers.insert("key", HeaderValue::from_static("value"));
                Ok(res)
            },
            Bytes::from_static(b"hello"),
            Duration::from_secs(1),
        );

        let aggregated = body.collect().await.unwrap();

        assert_eq!(aggregated.trailers().unwrap().get("key").unwrap(), "value");

        let buf = aggregated.to_bytes();

        assert_eq!(buf, b"hello world".as_slice());
    }

    #[tokio::test]
    async fn keep_alive_body_no_initial() {
        let body = KeepAliveBody::new(
            async {
                let mut res = Response::with_status(StatusCode::OK);
                res.body = Bytes::from_static(b"hello world").into();
                Ok(res)
            },
            Duration::from_secs(1),
        );

        let aggregated = body.collect().await.unwrap();

        let buf = aggregated.to_bytes();

        assert_eq!(buf, b"hello world".as_slice());
    }

    #[tokio::test]
    async fn keep_alive_body_fill_withespace() {
        let body = KeepAliveBody::new(
            async {
                tokio::time::sleep(Duration::from_millis(50)).await;

                let mut res = Response::with_status(StatusCode::OK);
                res.body = Bytes::from_static(b"hello world").into();
                Ok(res)
            },
            Duration::from_millis(10),
        );

        let aggregated = body.collect().await.unwrap();

        let buf = aggregated.to_bytes();

        assert_eq!(buf, b"     hello world".as_slice());
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ pub mod service;
pub mod stream;

pub mod checksum;
pub mod keep_alive_body;

pub use self::error::*;
pub use self::http::Body;
Loading