Unverified Commit 50abbeaa authored by Nugine's avatar Nugine
Browse files

s3s: dto: event stream

parent fbdefac6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ base64-simd = "0.8.0"
bytes = "1.3.0"
bytestring = "1.2.0"
chrono = { version = "0.4.23", default-features = false }
crc32fast = "1.3.2"
futures = { version = "0.3.25", default-features = false, features = ["std"] }
hex-simd = "0.8.0"
hmac = "0.12.1"
+234 −2
Original line number Diff line number Diff line
use super::SelectObjectContentEvent;
use super::{ContinuationEvent, EndEvent, ProgressEvent, RecordsEvent, StatsEvent};

use crate::stream::ByteStream;
use crate::stream::DynByteStream;
use crate::S3Error;
use crate::S3Result;
use crate::StdError;
use crate::{xml, S3ErrorCode};

use std::fmt;
use std::num::TryFromIntError;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::task::ready;
use std::task::{Context, Poll};

use bytes::BufMut;
use bytes::Bytes;
use futures::Stream;
use smallvec::SmallVec;
use tracing::debug;

pub struct SelectObjectContentEventStream {
    inner: Pin<Box<dyn Stream<Item = S3Result<SelectObjectContentEvent>> + Send + Sync + 'static>>,
@@ -20,6 +31,11 @@ impl SelectObjectContentEventStream {
    {
        Self { inner: Box::pin(stream) }
    }

    #[must_use]
    pub fn into_byte_stream(self) -> DynByteStream {
        Box::pin(Wrapper(self))
    }
}

impl Stream for SelectObjectContentEventStream {
@@ -41,3 +57,219 @@ impl fmt::Debug for SelectObjectContentEventStream {
            .finish_non_exhaustive()
    }
}

struct Wrapper(SelectObjectContentEventStream);

impl Stream for Wrapper {
    type Item = Result<Bytes, StdError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(Pin::new(&mut self.0).poll_next(cx)) {
            Some(ev) => Poll::Ready(Some(event_into_bytes(ev).map_err(|e| Box::new(e) as StdError))),
            None => Poll::Ready(None),
        }
    }
}

impl ByteStream for Wrapper {}

fn event_into_bytes(ev: S3Result<SelectObjectContentEvent>) -> Result<Bytes, SerError> {
    match ev {
        Ok(event) => event.into_message().serialize(),
        Err(err) => {
            debug!(?err, "SelectObjectContentEventStream: Request Level Error");
            request_level_error(&err).serialize()
        }
    }
}

struct Message {
    headers: SmallVec<[Header; 4]>,
    payload: Option<Bytes>,
}

struct Header {
    name: Bytes,
    value: Bytes,
}

#[derive(Debug, thiserror::Error)]
enum SerError {
    #[error("Message Serialization: LengthOverflow")]
    LengthOverflow,

    #[error("Message Serialization: IntOverflow: {0}")]
    IntOverflow(#[from] TryFromIntError),
}

impl Message {
    /// <https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html>
    fn serialize(self) -> Result<Bytes, SerError> {
        let total_byte_length: u32;
        let headers_byte_length: u32;
        {
            let headers_len = self.headers.iter().try_fold(0, |mut acc: usize, h| {
                acc = acc.checked_add(1 + 1 + 2)?;
                acc = acc.checked_add(h.name.len())?;
                acc = acc.checked_add(h.value.len())?;
                Some(acc)
            });

            let total_len = headers_len
                .and_then(|acc| acc.checked_add(4 + 4 + 4 + 4))
                .and_then(|acc| self.payload.as_ref().and_then(|payload| acc.checked_add(payload.len())));

            total_byte_length = u32::try_from(total_len.ok_or(SerError::LengthOverflow)?)?;
            headers_byte_length = u32::try_from(headers_len.ok_or(SerError::LengthOverflow)?)?;
        }

        let mut buf: Vec<u8> = Vec::with_capacity(total_byte_length as usize);
        buf.put_u32(total_byte_length);
        buf.put_u32(headers_byte_length);

        let prelude_crc = crc32fast::hash(&buf);
        buf.put_u32(prelude_crc);

        for h in &self.headers {
            let header_name_byte_length = u8::try_from(h.name.len())?;
            let value_string_byte_length = u16::try_from(h.value.len())?;

            buf.put_u8(header_name_byte_length);
            buf.put(&*h.name);

            buf.put_u8(7);
            buf.put_u16(value_string_byte_length);
            buf.put(&*h.value);
        }

        if let Some(payload) = self.payload.as_deref() {
            buf.put(payload);
        }

        let message_crc = crc32fast::hash(&buf);
        buf.put_u32(message_crc);

        Ok(buf.into())
    }
}

impl SelectObjectContentEvent {
    fn into_message(self) -> Message {
        match self {
            SelectObjectContentEvent::Cont(e) => e.into_message(),
            SelectObjectContentEvent::End(e) => e.into_message(),
            SelectObjectContentEvent::Progress(e) => e.into_message(),
            SelectObjectContentEvent::Records(e) => e.into_message(),
            SelectObjectContentEvent::Stats(e) => e.into_message(),
        }
    }
}

const EVENT_TYPE: &str = ":event-type";
const MESSAGE_TYPE: &str = ":message-type";
const CONTENT_TYPE: &str = ":content-type";

impl ContinuationEvent {
    fn into_message(self) -> Message {
        let headers = const_headers(&[
            (EVENT_TYPE, "Cont"),    //
            (MESSAGE_TYPE, "event"), //
        ]);
        let payload = None;
        Message { headers, payload }
    }
}

impl EndEvent {
    fn into_message(self) -> Message {
        let headers = const_headers(&[
            (EVENT_TYPE, "End"),     //
            (MESSAGE_TYPE, "event"), //
        ]);
        let payload = None;
        Message { headers, payload }
    }
}

impl ProgressEvent {
    fn into_message(self) -> Message {
        let headers = const_headers(&[
            (EVENT_TYPE, "Progress"),   //
            (CONTENT_TYPE, "text/xml"), //
            (MESSAGE_TYPE, "event"),    //
        ]);
        let payload = self.details.as_ref().map(xml_payload);
        Message { headers, payload }
    }
}

impl RecordsEvent {
    fn into_message(self) -> Message {
        let headers = const_headers(&[
            (EVENT_TYPE, "Records"),                    //
            (CONTENT_TYPE, "application/octet-stream"), //
            (MESSAGE_TYPE, "event"),                    //
        ]);
        let payload = self.payload;
        Message { headers, payload }
    }
}

impl StatsEvent {
    fn into_message(self) -> Message {
        let headers = const_headers(&[
            (EVENT_TYPE, "Stats"),      //
            (CONTENT_TYPE, "text/xml"), //
            (MESSAGE_TYPE, "event"),    //
        ]);
        let payload = self.details.as_ref().map(xml_payload);
        Message { headers, payload }
    }
}

fn const_headers(hs: &'static [(&'static str, &'static str)]) -> SmallVec<[Header; 4]> {
    let mut ans = SmallVec::with_capacity(hs.len());
    for (name, value) in hs {
        ans.push(header(static_str(name), static_str(value)));
    }
    ans
}

fn xml_payload<T: xml::Serialize>(val: &T) -> Bytes {
    let mut buf = Vec::with_capacity(256);
    {
        let mut ser = xml::Serializer::new(&mut buf);
        ser.decl()
            .and_then(|_| val.serialize(&mut ser))
            .expect("infallible serialization");
    }
    buf.into()
}

fn request_level_error(e: &S3Error) -> Message {
    let code = match e.code().as_static_str() {
        Some(s) => static_str(s),
        None => match e.code() {
            S3ErrorCode::Custom(s) => s.as_bytes().clone(),
            _ => unreachable!(),
        },
    };

    let message = e.message().map_or_else(Bytes::new, |s| Bytes::copy_from_slice(s.as_bytes()));

    let mut headers = SmallVec::with_capacity(3);
    headers.push(header(static_str(":error-code"), code));
    headers.push(header(static_str(":error-message"), message));
    headers.push(header(static_str(MESSAGE_TYPE), static_str("error")));
    Message { headers, payload: None }
}

#[inline(always)]
fn static_str(s: &'static str) -> Bytes {
    Bytes::from_static(s.as_bytes())
}

#[inline(always)]
fn header(name: Bytes, value: Bytes) -> Header {
    Header { name, value }
}