Unverified Commit dd6fac19 authored by Nugine's avatar Nugine
Browse files

s3s: stream

parent 5b7efedb
Loading
Loading
Loading
Loading
+16 −11
Original line number Diff line number Diff line
@@ -494,20 +494,21 @@ fn codegen_op_http_de_multipart(op: &Operation, rust_types: &RustTypes, g: &mut
    assert_eq!(op.name, "PutObject");

    g.ln(f!(
        "pub fn deserialize_http_multipart(req: &mut http::Request, mut m: http::Multipart) -> S3Result<{}> {{",
        "pub fn deserialize_http_multipart(req: &mut http::Request, m: http::Multipart) -> S3Result<{}> {{",
        op.input
    ));

    {
        g.ln("let bucket = http::unwrap_bucket(req);");
        g.ln("let key = http::parse_field_value(&m, \"key\")?.ok_or_else(|| invalid_request!(\"missing key\"))?;");
        g.lf();
    }

    {
        g.ln("let body: Option<StreamingBlob> = m.take_file_stream().map(StreamingBlob::wrap);");
        g.lf();
    }
    g.lines([
        "let bucket = http::unwrap_bucket(req);",
        "let key = http::parse_field_value(&m, \"key\")?.ok_or_else(|| invalid_request!(\"missing key\"))?;",
        "",
        "let vec_stream = req.extensions_mut().remove::<crate::stream::VecByteStream>().expect(\"missing vec stream\");",
        "",
        "let content_length = i64::try_from(vec_stream.exact_remaining_length()).map_err(|e|s3_error!(e, InvalidArgument, \"content-length overflow\"))?;",
        "",
        "let body: Option<StreamingBlob> = Some(StreamingBlob::new(vec_stream));",
        "",
    ]);

    let rust::Type::Struct(ty) = &rust_types[op.input.as_str()] else { panic!() };

@@ -519,6 +520,10 @@ fn codegen_op_http_de_multipart(op: &Operation, rust_types: &RustTypes, g: &mut
                assert!(header.as_bytes().iter().all(|&x| x == b'-' || x.is_ascii_alphanumeric()));
                let header = header.to_ascii_lowercase();

                if header == "content-length" {
                    continue;
                }

                let field_type = &rust_types[field.type_.as_str()];

                if let rust::Type::Timestamp(ts_ty) = field_type {
+46 −0
Original line number Diff line number Diff line
@@ -68,6 +68,12 @@ impl From<StreamingBlob> for DynByteStream {
    }
}

impl From<DynByteStream> for StreamingBlob {
    fn from(value: DynByteStream) -> Self {
        Self { inner: value }
    }
}

impl From<StreamingBlob> for Body {
    fn from(value: StreamingBlob) -> Self {
        Body::from(value.into_inner())
@@ -79,3 +85,43 @@ impl From<Body> for StreamingBlob {
        Self::new(value)
    }
}

pin_project_lite::pin_project! {
    pub(crate) struct StreamWrapper<S> {
        #[pin]
        inner: S
    }
}

impl<S, E> Stream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    type Item = Result<Bytes, StdError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.inner.poll_next(cx).map_err(|e| Box::new(e) as StdError)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl<S> ByteStream for StreamWrapper<S>
where
    StreamWrapper<S>: Stream<Item = Result<Bytes, StdError>>,
{
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}

fn wrap<S>(inner: S) -> DynByteStream
where
    StreamWrapper<S>: ByteStream<Item = Result<Bytes, StdError>> + Send + Sync + 'static,
{
    Box::pin(StreamWrapper { inner })
}
+14 −7
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
use crate::error::StdError;
use crate::header::AmzDate;
use crate::signature_v4;
use crate::stream::{ByteStream, RemainingLength};
use crate::stream::{ByteStream, DynByteStream, RemainingLength};
use crate::utils::SyncBoxFuture;

use std::convert::TryInto;
@@ -280,10 +280,6 @@ impl AwsChunkedStream {
        Some(Ok((bytes_buffer, remaining_bytes)))
    }

    pub fn remaining_length(&self) -> usize {
        self.remaining_length
    }

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, AwsChunkedStreamError>>> {
        let ans = Pin::new(&mut self.inner).poll_next(cx);
        if let Poll::Ready(Some(Ok(ref bytes))) = ans {
@@ -291,20 +287,31 @@ impl AwsChunkedStream {
        }
        ans
    }

    pub fn exact_remaining_length(&self) -> usize {
        self.remaining_length
    }

    pub fn into_byte_stream(self) -> DynByteStream {
        crate::stream::into_dyn(self)
    }
}

#[allow(clippy::missing_trait_methods)]
impl Stream for AwsChunkedStream {
    type Item = Result<Bytes, AwsChunkedStreamError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

impl ByteStream for AwsChunkedStream {
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::new_exact(self.remaining_length())
        RemainingLength::new_exact(self.remaining_length)
    }
}

+11 −1
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@
//!

use crate::error::StdError;
use crate::stream::ByteStream;
use crate::utils::SyncBoxFuture;

use std::fmt::{self, Debug};
@@ -342,13 +343,22 @@ impl FileStream {
    }
}

#[allow(clippy::missing_trait_methods)]
impl Stream for FileStream {
    type Item = Result<Bytes, FileStreamError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}

impl ByteStream for FileStream {
    fn remaining_length(&self) -> crate::stream::RemainingLength {
        crate::stream::RemainingLength::unknown()
    }
}

/// CRLF lines
+10 −4
Original line number Diff line number Diff line
@@ -4377,11 +4377,19 @@ impl PutObject {
        })
    }

    pub fn deserialize_http_multipart(req: &mut http::Request, mut m: http::Multipart) -> S3Result<PutObjectInput> {
    pub fn deserialize_http_multipart(req: &mut http::Request, m: http::Multipart) -> S3Result<PutObjectInput> {
        let bucket = http::unwrap_bucket(req);
        let key = http::parse_field_value(&m, "key")?.ok_or_else(|| invalid_request!("missing key"))?;

        let body: Option<StreamingBlob> = m.take_file_stream().map(StreamingBlob::wrap);
        let vec_stream = req
            .extensions_mut()
            .remove::<crate::stream::VecByteStream>()
            .expect("missing vec stream");

        let content_length = i64::try_from(vec_stream.exact_remaining_length())
            .map_err(|e| s3_error!(e, InvalidArgument, "content-length overflow"))?;

        let body: Option<StreamingBlob> = Some(StreamingBlob::new(vec_stream));

        let acl: Option<ObjectCannedACL> = http::parse_field_value(&m, "x-amz-acl")?;

@@ -4406,8 +4414,6 @@ impl PutObject {

        let content_language: Option<ContentLanguage> = http::parse_field_value(&m, "content-language")?;

        let content_length: ContentLength = http::parse_field_value(&m, "content-length")?.unwrap_or(0);

        let content_md5: Option<ContentMD5> = http::parse_field_value(&m, "content-md5")?;

        let content_type: Option<ContentType> = http::parse_field_value(&m, "content-type")?;
Loading