Unverified Commit 1e124784 authored by Nugine's avatar Nugine
Browse files

s3s: ops: decoded-content-length

parent ba0a4eb4
Loading
Loading
Loading
Loading
+39 −7
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ use transform_stream::AsyncTryStream;
pub struct AwsChunkedStream {
    /// inner
    inner: AsyncTryStream<Bytes, AwsChunkedStreamError, BoxFuture<'static, Result<(), AwsChunkedStreamError>>>,

    remaining_length: usize,
}

impl Debug for AwsChunkedStream {
@@ -107,7 +109,14 @@ fn check_signature(ctx: &SignatureCtx, expected_signature: &[u8], chunk_data: &[

impl AwsChunkedStream {
    /// Constructs a `ChunkedStream`
    pub fn new<S>(body: S, seed_signature: Box<str>, amz_date: AmzDate, region: Box<str>, secret_key: Box<str>) -> Self
    pub fn new<S>(
        body: S,
        seed_signature: Box<str>,
        amz_date: AmzDate,
        region: Box<str>,
        secret_key: Box<str>,
        decoded_content_length: usize,
    ) -> Self
    where
        S: Stream<Item = io::Result<Bytes>> + Send + 'static,
    {
@@ -162,7 +171,10 @@ impl AwsChunkedStream {
                Ok(())
            })
        });
        Self { inner }
        Self {
            inner,
            remaining_length: decoded_content_length,
        }
    }

    /// read meta bytes and return remaining bytes
@@ -266,14 +278,26 @@ impl AwsChunkedStream {

        Some(Ok((bytes_buffer, remaining_bytes)))
    }

    pub fn remaining_length(&self) -> usize {
        self.remaining_length
    }

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, AwsChunkedStreamError>>> {
        let ans = Pin::new(&mut self.inner).poll_next(cx);
        if let Poll::Ready(Some(Ok(ref bytes))) = ans {
            self.remaining_length = self.remaining_length.saturating_sub(bytes.len());
        }
        ans
    }
}

#[allow(clippy::missing_trait_methods)]
impl Stream for AwsChunkedStream {
    type Item = Result<Bytes, AwsChunkedStreamError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner).poll_next(cx)
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll(cx)
    }
}

@@ -297,10 +321,12 @@ mod tests {

        let chunk1_data = vec![b'a'; 0x10000]; // 65536
        let chunk2_data = vec![b'a'; 1024];
        let chunk3_data = [];
        let decoded_content_length = chunk1_data.len() + chunk2_data.len() + chunk3_data.len();

        let chunk1 = join(&[chunk1_meta, &chunk1_data, b"\r\n"]);
        let chunk2 = join(&[chunk2_meta, &chunk2_data, b"\r\n"]);
        let chunk3 = join(&[chunk3_meta, b"\r\n"]);
        let chunk3 = join(&[chunk3_meta, &chunk3_data, b"\r\n"]);

        let chunk_results: Vec<Result<Bytes, _>> = vec![Ok(chunk1), Ok(chunk2), Ok(chunk3)];

@@ -312,8 +338,14 @@ mod tests {
        let date = AmzDate::parse(timestamp).unwrap();

        let stream = futures::stream::iter(chunk_results.into_iter());
        let mut chunked_stream =
            AwsChunkedStream::new(stream, seed_signature.into(), date, region.into(), secret_access_key.into());
        let mut chunked_stream = AwsChunkedStream::new(
            stream,
            seed_signature.into(),
            date,
            region.into(),
            secret_access_key.into(),
            decoded_content_length,
        );

        let ans1 = chunked_stream.next().await.unwrap();
        assert_eq!(ans1.unwrap(), chunk1_data.as_slice());
+22 −8
Original line number Diff line number Diff line
@@ -123,6 +123,15 @@ fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
        Err(e) => Err(invalid_request!(e, "invalid header: x-amz-date")),
    }
}

fn extract_decoded_content_length(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<usize>> {
    let Some(val) = hs.get(crate::header::names::X_AMZ_DECODED_CONTENT_LENGTH) else { return Ok(None) };
    match atoi::atoi::<usize>(val.as_bytes()) {
        Some(x) => Ok(Some(x)),
        None => Err(invalid_request!("invalid header: x-amz-decoded-content-length")),
    }
}

pub async fn call(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<Response> {
    match call_inner(req, s3, auth, base_domain).await {
        Ok(res) => Ok(res),
@@ -140,6 +149,7 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
        let mut body = mem::take(req.body_mut());
        let headers = extract_headers(req)?;
        let mime = extract_mime(&headers)?;
        let decoded_content_length = extract_decoded_content_length(&headers)?;
        let body_transformed;
        {
            let mut scx = SignatureContext {
@@ -151,6 +161,7 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
                body,
                multipart: None,
                body_transformed: false,
                decoded_content_length,
            };

            scx.check().await?;
@@ -161,21 +172,16 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
        }
        *req.body_mut() = body;
        if body_transformed {
            let len = req
                .headers()
                .get(header::names::X_AMZ_DECODED_CONTENT_LENGTH)
                .and_then(|val| atoi::atoi::<u64>(val.as_bytes()))
                .unwrap_or(0);

            if let Some(val) = req.headers_mut().get_mut(header::names::CONTENT_LENGTH) {
                let len = decoded_content_length.unwrap_or(0);
                if len > 0 {
                    *val = crate::utils::fmt_u64(len, |s| http::HeaderValue::try_from(s).unwrap())
                    *val = crate::utils::fmt_usize(len, |s| http::HeaderValue::try_from(s).unwrap())
                } else {
                    *val = http::HeaderValue::from_static("0");
                }
            }
        }
        debug!(?body_transformed);
        debug!(?body_transformed, ?decoded_content_length);
    }

    let (op, needs_full_body) = 'resolve: {
@@ -219,6 +225,7 @@ struct SignatureContext<'a> {
    body: Body,
    multipart: Option<Multipart>,
    body_transformed: bool,
    decoded_content_length: Option<usize>,
}

impl SignatureContext<'_> {
@@ -397,6 +404,10 @@ impl SignatureContext<'_> {
        }

        if is_stream {
            let decoded_content_length = self
                .decoded_content_length
                .ok_or_else(|| s3_error!(MissingContentLength, "missing header: x-amz-decoded-content-length"))?;

            let body = io_stream(mem::take(&mut self.body));

            let chunked_stream = AwsChunkedStream::new(
@@ -405,8 +416,11 @@ impl SignatureContext<'_> {
                amz_date,
                authorization.credential.aws_region.into(),
                secret_key.into(),
                decoded_content_length,
            );

            debug!(len=?chunked_stream.remaining_length(), "aws-chunked");

            self.body = Body::wrap_stream(chunked_stream);
            self.body_transformed = true;
        }
+1 −1
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ pub fn fmt_long<T>(val: i64, f: impl FnOnce(&str) -> T) -> T {
    f(buf.as_str())
}

pub fn fmt_u64<T>(val: u64, f: impl FnOnce(&str) -> T) -> T {
pub fn fmt_usize<T>(val: usize, f: impl FnOnce(&str) -> T) -> T {
    let mut buf = ArrayString::<32>::new();
    write!(&mut buf, "{val}").unwrap();
    f(buf.as_str())