Commit 59d6fd97 authored by Nugine's avatar Nugine
Browse files

feat(s3s): support streaming trailers

parent 6bf36f92
Loading
Loading
Loading
Loading
+395 −13
Original line number Diff line number Diff line
@@ -2,13 +2,18 @@

use crate::auth::SecretKey;
use crate::error::StdError;
use crate::protocol::TrailingHeaders;
use crate::sig_v4;
use crate::sig_v4::AmzDate;
use crate::sig_v4::create_trailer_string_to_sign;
use crate::stream::{ByteStream, DynByteStream, RemainingLength};
use crate::utils::SyncBoxFuture;

use hyper::HeaderMap;
use hyper::http::{HeaderName, HeaderValue};
use std::fmt::{self, Debug};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use futures::pin_mut;
@@ -23,6 +28,9 @@ pub struct AwsChunkedStream {
    inner: AsyncTryStream<Bytes, AwsChunkedStreamError, SyncBoxFuture<'static, Result<(), AwsChunkedStreamError>>>,

    remaining_length: usize,

    // Parsed trailing headers (lower-cased names) if present and verified.
    trailers: Arc<Mutex<Option<HeaderMap<HeaderValue>>>>,
}

impl Debug for AwsChunkedStream {
@@ -68,17 +76,19 @@ pub enum AwsChunkedStreamError {
}

/// Chunk meta
#[derive(Debug)]
struct ChunkMeta<'a> {
    /// chunk size
    size: usize,
    /// chunk signature
    signature: &'a [u8],
    /// Optional chunk signature.
    /// `Some` for signed chunks,
    /// `None` for unsigned streaming
    signature: Option<&'a [u8]>,
}

/// nom parser
fn parse_chunk_meta(mut input: &[u8]) -> nom::IResult<&[u8], ChunkMeta<'_>> {
    use crate::utils::parser::consume;

    use nom::bytes::complete::{tag, take, take_till1};
    use nom::combinator::{all_consuming, map_res};
    use nom::number::complete::hex_u32;
@@ -86,10 +96,19 @@ fn parse_chunk_meta(mut input: &[u8]) -> nom::IResult<&[u8], ChunkMeta<'_>> {

    let s = &mut input;

    let size = consume(s, take_till1(|c| c == b';'))?;
    // read size until ';' or CR
    let size = consume(s, take_till1(|c| c == b';' || c == b'\r'))?;
    let (_, size) = map_res(hex_u32, TryInto::try_into)(size)?;

    let signature = consume(s, all_consuming(delimited(tag(b";chunk-signature="), take(64_usize), tag(b"\r\n"))))?;
    // Either ";chunk-signature=<64>\r\n" or just "\r\n"
    let signature = if let Ok(sig) = consume(s, all_consuming(delimited(tag(b";chunk-signature="), take(64_usize), tag(b"\r\n"))))
    {
        Some(sig)
    } else {
        // If no signature extension, accept plain CRLF
        let _ = consume(s, all_consuming(tag(b"\r\n")))?;
        None
    };

    Ok((input, ChunkMeta { size, signature }))
}
@@ -106,6 +125,7 @@ fn check_signature(ctx: &SignatureCtx, expected_signature: &[u8], chunk_data: &[

impl AwsChunkedStream {
    /// Constructs a `ChunkedStream`
    #[allow(clippy::too_many_arguments)]
    pub fn new<S>(
        body: S,
        seed_signature: Box<str>,
@@ -114,10 +134,13 @@ impl AwsChunkedStream {
        service: Box<str>,
        secret_key: SecretKey,
        decoded_content_length: usize,
        unsigned: bool,
    ) -> Self
    where
        S: Stream<Item = Result<Bytes, StdError>> + Send + Sync + 'static,
    {
        let trailers: Arc<Mutex<Option<HeaderMap<HeaderValue>>>> = Arc::new(Mutex::new(None));
        let trailers_for_worker = Arc::clone(&trailers);
        let inner = AsyncTryStream::<_, _, SyncBoxFuture<'static, Result<(), AwsChunkedStreamError>>>::new(|mut y| {
            #[allow(clippy::shadow_same)] // necessary for `pin_mut!`
            Box::pin(async move {
@@ -140,31 +163,71 @@ impl AwsChunkedStream {
                            Some(Ok(remaining_bytes)) => prev_bytes = remaining_bytes,
                        }
                        if let Ok((_, meta)) = parse_chunk_meta(&buf) {
                            // Enforce signature presence based on mode
                            if !unsigned && meta.signature.is_none() {
                                return Err(AwsChunkedStreamError::FormatError);
                            }
                            meta
                        } else {
                            return Err(AwsChunkedStreamError::FormatError);
                        }
                    };

                    let data: Vec<Bytes> = {
                    tracing::trace!(?meta);

                    let (data, remaining_after_data): (Vec<Bytes>, Bytes) = {
                        match Self::read_data(body.as_mut(), prev_bytes, meta.size).await {
                            None => return Err(AwsChunkedStreamError::Incomplete),
                            Some(Err(e)) => return Err(e),
                            Some(Ok((data, remaining_bytes))) => {
                                prev_bytes = remaining_bytes;
                                data
                            }
                            Some(Ok((data, remaining_bytes))) => (data, remaining_bytes),
                        }
                    };

                    match check_signature(&ctx, meta.signature, &data) {
                    if let Some(expected_sig) = meta.signature {
                        match check_signature(&ctx, expected_sig, &data) {
                            None => return Err(AwsChunkedStreamError::SignatureMismatch),
                            Some(signature) => ctx.prev_signature = signature,
                        }
                    }

                    for bytes in data {
                        y.yield_ok(bytes).await;
                    }

                    if meta.size == 0 {
                        // Try to read all remaining bytes (including any immediate remainder) as trailing headers
                        // block and verify signature if present.
                        match Self::read_and_verify_trailers(body.as_mut(), remaining_after_data, &ctx, unsigned).await {
                            // No more bytes => no trailers, treat as normal termination.
                            None => break,
                            Some(Err(e)) => return Err(e),
                            Some(Ok((entries, _remaining_bytes))) => {
                                // After trailers there must not be any additional payload bytes in a valid request.
                                // However, if there are remaining bytes, we will just attempt to continue parsing
                                // which will likely fail with FormatError.
                                // Build HeaderMap from entries and store.
                                let mut map: HeaderMap<HeaderValue> = HeaderMap::new();
                                for (name, value) in entries {
                                    // Names are already lower-cased ASCII
                                    let hn: HeaderName = match name.parse() {
                                        Ok(h) => h,
                                        Err(_) => return Err(AwsChunkedStreamError::FormatError),
                                    };
                                    let hv: HeaderValue = match HeaderValue::from_str(&value) {
                                        Ok(v) => v,
                                        Err(_) => return Err(AwsChunkedStreamError::FormatError),
                                    };
                                    map.append(hn, hv);
                                }
                                tracing::debug!(trailers=?map);
                                *trailers_for_worker.lock().unwrap() = Some(map);
                                break;
                            }
                        }
                    }

                    // Carry over any remaining bytes to next iteration when there are more chunks.
                    prev_bytes = remaining_after_data;
                }

                Ok(())
@@ -173,7 +236,131 @@ impl AwsChunkedStream {
        Self {
            inner,
            remaining_length: decoded_content_length,
            trailers,
        }
    }

    /// Reads the rest of the underlying stream after the final 0-size chunk and,
    /// if a trailer block is present, parses it and verifies its signature.
    ///
    /// Returns:
    /// - `None`: there were no more bytes (or only a bare CRLF), i.e. no trailers present
    /// - `Some(Ok((entries, remaining_bytes)))`: parsed trailer name/value pairs
    ///   (excluding `x-amz-trailer-signature`); `remaining_bytes` is always empty here
    ///   because the stream is drained to EOF
    /// - `Some(Err(..))`: on underlying I/O errors, malformed trailers, or signature mismatch
    async fn read_and_verify_trailers<S>(
        mut body: Pin<&mut S>,
        prev_bytes: Bytes,
        ctx: &SignatureCtx,
        unsigned: bool,
    ) -> Option<Result<(Vec<(String, String)>, Bytes), AwsChunkedStreamError>>
    where
        S: Stream<Item = Result<Bytes, StdError>> + Send + 'static,
    {
        // Accumulate all remaining bytes until EOF.
        let mut buf: Vec<u8> = Vec::new();

        if !prev_bytes.is_empty() {
            buf.extend_from_slice(prev_bytes.as_ref());
        }

        // Read to end.
        // NOTE(review): this buffers the entire trailer block in memory; trailers are
        // expected to be small, but the size is not bounded here — confirm acceptable.
        while let Some(next) = body.next().await {
            match next {
                Err(e) => return Some(Err(AwsChunkedStreamError::Underlying(e))),
                Ok(bytes) => buf.extend_from_slice(bytes.as_ref()),
            }
        }

        if buf.is_empty() || buf == b"\r\n" {
            // No trailers present (a lone CRLF is just the chunk terminator).
            return None;
        }

        // Parse trailing headers lines. Lines are separated by CRLF or LF.
        // Build canonical_trailers by sorting headers (excluding x-amz-trailer-signature) and joining as `name:value\n`.
        // Extract the provided trailer signature from header `x-amz-trailer-signature`.
        let (canonical_trailers, provided_signature, entries) = match Self::parse_trailing_headers(&buf) {
            Ok(x) => x,
            Err(e) => return Some(Err(e)),
        };

        // Verify the trailer signature if present, or require it when unsigned=false.
        if let Some(provided) = provided_signature.as_ref() {
            // The trailer string-to-sign chains off the previous (last chunk) signature.
            let string_to_sign =
                create_trailer_string_to_sign(&ctx.amz_date, &ctx.region, &ctx.service, &ctx.prev_signature, &canonical_trailers);
            let trailer_signature =
                sig_v4::calculate_signature(&string_to_sign, &ctx.secret_key, &ctx.amz_date, &ctx.region, &ctx.service);
            if trailer_signature.as_bytes() != provided.as_slice() {
                return Some(Err(AwsChunkedStreamError::SignatureMismatch));
            }
        } else if !unsigned {
            // In signed-with-trailer mode, missing trailer signature is an error
            return Some(Err(AwsChunkedStreamError::FormatError));
        }

        Some(Ok((entries, Bytes::new())))
    }

    /// Parses a raw trailer block into `(canonical_bytes, trailer_signature, entries)`.
    ///
    /// Lines are separated by LF with an optional preceding CR; the last line may
    /// lack a terminator. Header names are lower-cased, values are trimmed of
    /// surrounding ASCII space/tab. The `x-amz-trailer-signature` header is pulled
    /// out separately and excluded from both `entries` and the canonical form.
    #[allow(clippy::type_complexity)]
    fn parse_trailing_headers(buf: &[u8]) -> Result<(Vec<u8>, Option<Vec<u8>>, Vec<(String, String)>), AwsChunkedStreamError> {
        let mut entries: Vec<(String, String)> = Vec::new();
        let mut provided_signature: Option<Vec<u8>> = None;

        for raw_line in buf.split(|&b| b == b'\n') {
            // Drop a single trailing CR so CRLF and LF line endings are both accepted.
            let line = raw_line.strip_suffix(b"\r").unwrap_or(raw_line);
            if line.is_empty() {
                continue;
            }

            // Split "name:value" at the first colon; a line without one is malformed.
            let colon_pos = line
                .iter()
                .position(|&b| b == b':')
                .ok_or(AwsChunkedStreamError::FormatError)?;
            let name_raw = &line[..colon_pos];
            let value_raw = &line[colon_pos + 1..];

            // Canonical header names are lowercase ASCII.
            let name = String::from_utf8(name_raw.to_ascii_lowercase()).map_err(|_| AwsChunkedStreamError::FormatError)?;

            // Values participate in the canonical form with surrounding whitespace removed.
            let value = String::from_utf8(trim_ascii_whitespace(value_raw).to_vec())
                .map_err(|_| AwsChunkedStreamError::FormatError)?;

            if name == "x-amz-trailer-signature" {
                provided_signature = Some(value.into_bytes());
            } else {
                entries.push((name, value));
            }
        }

        // Sort by header name to canonicalize deterministically (stable sort keeps
        // the relative order of duplicate names).
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        // Build canonical bytes: `name:value\n` per entry.
        let mut canonical: Vec<u8> = Vec::new();
        for (n, v) in &entries {
            canonical.extend_from_slice(n.as_bytes());
            canonical.push(b':');
            canonical.extend_from_slice(v.as_bytes());
            canonical.push(b'\n');
        }

        Ok((canonical, provided_signature, entries))
    }

    /// read meta bytes and return remaining bytes
@@ -220,6 +407,10 @@ impl AwsChunkedStream {
    where
        S: Stream<Item = Result<Bytes, StdError>> + Send + 'static,
    {
        if data_size == 0 {
            return Some(Ok((Vec::new(), prev_bytes)));
        }

        let mut bytes_buffer = Vec::new();
        let mut push_data_bytes = |mut bytes: Bytes| {
            if data_size == 0 {
@@ -294,6 +485,15 @@ impl AwsChunkedStream {
    pub fn into_byte_stream(self) -> DynByteStream {
        crate::stream::into_dyn(self)
    }

    // Note: Trailing headers should be accessed via trailing_headers_handle().

    /// Get a handle to access verified trailing headers later.
    /// This can be cloned and stored outside to retrieve trailers after the
    /// stream has been fully read.
    pub(crate) fn trailing_headers_handle(&self) -> TrailingHeaders {
        TrailingHeaders::from_shared(Arc::clone(&self.trailers))
    }
}

impl Stream for AwsChunkedStream {
@@ -314,6 +514,17 @@ impl ByteStream for AwsChunkedStream {
    }
}

/// Returns `s` with leading and trailing ASCII space/tab bytes removed.
///
/// Only `b' '` and `b'\t'` are stripped; CR/LF and other ASCII whitespace
/// are intentionally left untouched.
#[inline]
fn trim_ascii_whitespace(s: &[u8]) -> &[u8] {
    let is_sp = |b: &u8| matches!(b, b' ' | b'\t');
    let start = s.iter().position(|b| !is_sp(b)).unwrap_or(s.len());
    let end = s.iter().rposition(|b| !is_sp(b)).map_or(start, |i| i + 1);
    &s[start..end]
}

#[cfg(test)]
mod tests {
    use super::*;
@@ -360,6 +571,7 @@ mod tests {
            service.into(),
            secret_access_key.into(),
            decoded_content_length,
            false,
        );

        let ans1 = chunked_stream.next().await.unwrap();
@@ -374,4 +586,174 @@ mod tests {
            assert!(chunked_stream.next().await.is_none());
        }
    }

    #[tokio::test]
    async fn example_put_object_chunked_stream_with_trailers() {
        // Example from AWS docs: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming-trailers.html
        // Seed signature corresponds to STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER canonical request
        let chunk1_meta = b"10000;chunk-signature=b474d8862b1487a5145d686f57f013e54db672cee1c953b3010fb58501ef5aa2\r\n";
        let chunk2_meta = b"400;chunk-signature=1c1344b170168f8e65b41376b44b20fe354e373826ccbbe2c1d40a8cae51e5c7\r\n";
        let chunk3_meta = b"0;chunk-signature=2ca2aba2005185cf7159c6277faf83795951dd77a3a99e6e65d5c9f85863f992\r\n";

        let chunk1_data = vec![b'a'; 0x10000]; // 65536
        let chunk2_data = vec![b'a'; 1024];
        let chunk3_data = [];
        let decoded_content_length = chunk1_data.len() + chunk2_data.len() + chunk3_data.len();

        let chunk1 = join(&[chunk1_meta, &chunk1_data, b"\r\n"]);
        let chunk2 = join(&[chunk2_meta, &chunk2_data, b"\r\n"]);
        let chunk3 = join(&[chunk3_meta, &chunk3_data, b"\r\n"]);

        // Trailer block: checksum header followed by x-amz-trailer-signature,
        // with no trailing CRLF (matching the AWS example bytes).
        let trailers_block = Bytes::from_static(b"x-amz-checksum-crc32c:sOO8/Q==\r\nx-amz-trailer-signature:d81f82fc3505edab99d459891051a732e8730629a2e4a59689829ca17fe2e435");

        let chunk_results: Vec<Result<Bytes, _>> = vec![Ok(chunk1), Ok(chunk2), Ok(chunk3), Ok(trailers_block)];

        let seed_signature = "106e2a8a18243abcf37539882f36619c00e2dfc72633413f02d3b74544bfeb8e";
        let timestamp = "20130524T000000Z";
        let region = "us-east-1";
        let service = "s3";
        let secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";

        let date = AmzDate::parse(timestamp).unwrap();

        let stream = futures::stream::iter(chunk_results);
        let mut chunked_stream = AwsChunkedStream::new(
            stream,
            seed_signature.into(),
            date,
            region.into(),
            service.into(),
            secret_access_key.into(),
            decoded_content_length,
            false,
        );

        // The decoded stream yields only chunk payloads, never meta lines or trailers.
        let ans1 = chunked_stream.next().await.unwrap();
        assert_eq!(ans1.unwrap(), chunk1_data.as_slice());

        let ans2 = chunked_stream.next().await.unwrap();
        assert_eq!(ans2.unwrap(), chunk2_data.as_slice());

        // No more data after verifying trailers
        assert!(chunked_stream.next().await.is_none());

        // Export trailers via handle; the signature header itself is excluded.
        let handle = chunked_stream.trailing_headers_handle();
        let trailers = handle.take().expect("trailers present");
        assert_eq!(trailers.len(), 1);
        let v = trailers.get("x-amz-checksum-crc32c").unwrap();
        assert_eq!(v, &HeaderValue::from_static("sOO8/Q=="));
    }

    #[tokio::test]
    async fn unsigned_payload_with_trailer_minimal() {
        // Construct a minimal unsigned aws-chunked stream: two data chunks and a 0 chunk, then trailers block.
        // Here we compute signatures using existing test vectors in sig_v4::methods.rs indirectly by using
        // a known seed and the create_trailer_string_to_sign path inside AwsChunkedStream.

        // For unsigned per-chunk mode, meta lines are plain sizes without extensions.
        let chunk1_meta = b"3\r\n"; // size 3
        let chunk2_meta = b"4\r\n"; // size 4
        let chunk3_meta = b"0\r\n"; // last chunk

        let chunk1_data = b"abc";
        let chunk2_data = b"defg";
        let decoded_content_length = chunk1_data.len() + chunk2_data.len();

        let chunk1 = join(&[chunk1_meta, chunk1_data.as_ref(), b"\r\n"]);
        let chunk2 = join(&[chunk2_meta, chunk2_data.as_ref(), b"\r\n"]);
        let chunk3 = join(&[chunk3_meta, b"\r\n"]); // 0-size chunk is followed by CRLF before trailers

        // Build trailers. We'll compute the expected trailer signature using the same algorithm as the stream.
        let seed_signature = "106e2a8a18243abcf37539882f36619c00e2dfc72633413f02d3b74544bfeb8e"; // from example seed for trailer flows
        let timestamp = "20130524T000000Z";
        let region = "us-east-1";
        let service = "s3";
        let secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        let date = AmzDate::parse(timestamp).unwrap();

        // Canonical trailers: one additional header besides x-amz-trailer-signature
        // (note: canonical form uses the trimmed value, hence "bar" not " bar").
        let canonical = b"x-amz-meta-foo:bar\n".to_vec();
        let string_to_sign = create_trailer_string_to_sign(&date, region, service, seed_signature, &canonical);
        let sig = sig_v4::calculate_signature(&string_to_sign, &SecretKey::from(secret_access_key), &date, region, service);
        let trailers_block = Bytes::from(format!("x-amz-meta-foo: bar\r\nx-amz-trailer-signature:{sig}"));

        let chunk_results: Vec<Result<Bytes, _>> = vec![Ok(chunk1), Ok(chunk2), Ok(chunk3), Ok(trailers_block)];

        let stream = futures::stream::iter(chunk_results);
        let mut chunked_stream = AwsChunkedStream::new(
            stream,
            seed_signature.into(),
            date,
            region.into(),
            service.into(),
            secret_access_key.into(),
            decoded_content_length,
            true, // unsigned
        );

        let ans1 = chunked_stream.next().await.unwrap();
        assert_eq!(ans1.unwrap(), chunk1_data.as_slice());

        let ans2 = chunked_stream.next().await.unwrap();
        assert_eq!(ans2.unwrap(), chunk2_data.as_slice());

        // No more data after verifying trailers
        assert!(chunked_stream.next().await.is_none());

        // The stored trailer value is whitespace-trimmed ("bar", not " bar").
        let handle = chunked_stream.trailing_headers_handle();
        let trailers = handle.take().expect("trailers present");
        assert_eq!(trailers.len(), 1);
        let v = trailers.get("x-amz-meta-foo").unwrap();
        assert_eq!(v, &HeaderValue::from_static("bar"));
    }

    #[tokio::test]
    async fn unsigned_payload_with_trailer_no_signature() {
        // unsigned mode with trailers present but no x-amz-trailer-signature
        // (the signature is optional when the payload is unsigned)
        let chunk1_meta = b"3\r\n"; // size 3
        let chunk2_meta = b"0\r\n"; // last chunk

        let chunk1_data = b"xyz";
        let decoded_content_length = chunk1_data.len();

        let chunk1 = join(&[chunk1_meta, chunk1_data.as_ref(), b"\r\n"]);
        let chunk2 = join(&[chunk2_meta, b"\r\n"]);

        // Trailers without signature
        let trailers_block = Bytes::from_static(b"x-amz-meta-a: 1\r\nx-amz-meta-b: 2");

        let chunk_results: Vec<Result<Bytes, _>> = vec![Ok(chunk1), Ok(chunk2), Ok(trailers_block)];

        let seed_signature = "deadbeef"; // not used for unsigned per-chunk, but used to build ctx
        let timestamp = "20130524T000000Z";
        let region = "us-east-1";
        let service = "s3";
        let secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        let date = AmzDate::parse(timestamp).unwrap();

        let stream = futures::stream::iter(chunk_results);
        let mut chunked_stream = AwsChunkedStream::new(
            stream,
            seed_signature.into(),
            date,
            region.into(),
            service.into(),
            secret_access_key.into(),
            decoded_content_length,
            true, // unsigned
        );

        let ans1 = chunked_stream.next().await.unwrap();
        assert_eq!(ans1.unwrap(), chunk1_data.as_slice());

        // No more data after verifying trailers (which contain no signature in unsigned mode)
        assert!(chunked_stream.next().await.is_none());

        // Both trailer headers are exposed, values whitespace-trimmed.
        let handle = chunked_stream.trailing_headers_handle();
        let trailers = handle.take().expect("trailers present");
        assert_eq!(trailers.len(), 2);
        assert_eq!(trailers.get("x-amz-meta-a").unwrap(), &HeaderValue::from_static("1"));
        assert_eq!(trailers.get("x-amz-meta-b").unwrap(), &HeaderValue::from_static("2"));
    }
}
+2 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ use super::OrderedQs;
use crate::HttpRequest;
use crate::auth::Credentials;
use crate::path::S3Path;
use crate::protocol::TrailingHeaders;
use crate::stream::VecByteStream;

use hyper::HeaderMap;
@@ -34,6 +35,7 @@ pub(crate) struct S3Extensions {
    pub credentials: Option<Credentials>,
    pub region: Option<String>,
    pub service: Option<String>,
    pub trailing_headers: Option<TrailingHeaders>,
}

impl From<HttpRequest> for Request {
+1 −0
Original line number Diff line number Diff line
@@ -47,3 +47,4 @@ pub use self::protocol::HttpRequest;
pub use self::protocol::HttpResponse;
pub use self::protocol::S3Request;
pub use self::protocol::S3Response;
pub use self::protocol::TrailingHeaders;
+4 −0
Original line number Diff line number Diff line
@@ -62,6 +62,7 @@ fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
    let credentials = req.s3ext.credentials.take();
    let region = req.s3ext.region.take();
    let service = req.s3ext.service.take();
    let trailing_headers = req.s3ext.trailing_headers.take();

    S3Request {
        input,
@@ -72,6 +73,7 @@ fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
        credentials,
        region,
        service,
        trailing_headers,
    }
}

@@ -303,6 +305,7 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>

                multipart: None,
                transformed_body: None,
                trailing_headers: None,
            };

            let credentials = scx.check().await?;
@@ -311,6 +314,7 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
            transformed_body = scx.transformed_body;

            req.s3ext.multipart = scx.multipart;
            req.s3ext.trailing_headers = scx.trailing_headers;

            match credentials {
                Some(cred) => {
+77 −30

File changed.

Preview size limit exceeded, changes collapsed.

Loading