Unverified Commit c1fb5a2e authored by Nugine's avatar Nugine
Browse files

s3s: http: ordered_headers: get_unique

parent 36bda786
Loading
Loading
Loading
Loading
+38 −9
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ use hyper::http::HeaderValue;
use hyper::HeaderMap;
use smallvec::SmallVec;

use crate::utils::stable_sort_by_first;

/// Immutable http header container
#[derive(Debug)]
pub struct OrderedHeaders<'a> {
@@ -20,9 +22,13 @@ impl<'a> OrderedHeaders<'a> {
    #[cfg(test)]
    #[must_use]
    pub fn from_slice_unchecked(slice: &[(&'a str, &'a str)]) -> Self {
        for (name, _) in slice {
            let is_valid = |c: u8| c == b'-' || c.is_ascii_lowercase() || c.is_ascii_digit();
            assert!(name.as_bytes().iter().copied().all(is_valid));
        }
        let mut headers = SmallVec::new();
        headers.extend_from_slice(slice);
        headers.sort_unstable();
        stable_sort_by_first(&mut headers);
        Self { headers }
    }

@@ -36,20 +42,43 @@ impl<'a> OrderedHeaders<'a> {
        for (name, value) in map.iter() {
            headers.push((name.as_str(), value.to_str()?));
        }
        headers.sort_unstable();
        stable_sort_by_first(&mut headers);

        Ok(Self { headers })
    }

    fn get_pair(&self, name: &'_ str) -> Option<(&'a str, &'a str)> {
        let headers = self.headers.as_slice();
        let idx = headers.binary_search_by_key(&name, |&(n, _)| n).ok()?;
        headers.get(idx).copied()
    fn get_all_pairs(&self, name: &str) -> impl Iterator<Item = (&'a str, &'a str)> + '_ {
        let slice = self.headers.as_slice();

        let lower_bound = slice.partition_point(|x| x.0 < name);
        let upper_bound = slice.partition_point(|x| x.0 <= name);

        slice[lower_bound..upper_bound].iter().copied()
    }

    pub fn get_all(&self, name: impl AsRef<str>) -> impl Iterator<Item = &'a str> + '_ {
        self.get_all_pairs(name.as_ref()).map(|x| x.1)
    }

    fn get_unique_pair(&self, name: &'_ str) -> Option<(&'a str, &'a str)> {
        let slice = self.headers.as_slice();
        let lower_bound = slice.partition_point(|x| x.0 < name);

        let mut iter = slice[lower_bound..].iter().copied();
        let pair = iter.next()?;

        if let Some(following) = iter.next() {
            if following.0 == name {
                return None;
            }
        }

        (pair.0 == name).then_some(pair)
    }

    /// Gets header value by name. Time `O(logn)`
    pub fn get(&self, name: impl AsRef<str>) -> Option<&'a str> {
        self.get_pair(name.as_ref()).map(|(_, v)| v)
    pub fn get_unique(&self, name: impl AsRef<str>) -> Option<&'a str> {
        self.get_unique_pair(name.as_ref()).map(|(_, v)| v)
    }

    /// Finds headers by names. Time `O(mlogn)`
@@ -57,7 +86,7 @@ impl<'a> OrderedHeaders<'a> {
    pub fn find_multiple(&self, names: &[impl AsRef<str>]) -> Self {
        let mut headers: SmallVec<[(&'a str, &'a str); 16]> = SmallVec::new();
        for name in names {
            if let Some(pair) = self.get_pair(name.as_ref()) {
            for pair in self.get_all_pairs(name.as_ref()) {
                headers.push(pair);
            }
        }
+24 −5
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ impl OrderedQs {
        let lower_bound = qs.partition_point(|x| x.0.as_str() < name);

        let mut iter = qs[lower_bound..].iter();
        let val = iter.next()?.1.as_str();
        let pair = iter.next()?;

        if let Some(following) = iter.next() {
            if following.0 == name {
@@ -67,7 +67,7 @@ impl OrderedQs {
            }
        }

        Some(val)
        (pair.0.as_str() == name).then_some(pair.1.as_str())
    }
}

@@ -83,8 +83,27 @@ mod tests {

    #[test]
    fn tag() {
        {
            let query = "tagging";
            let qs = OrderedQs::parse(query).unwrap();
            assert_eq!(qs.as_ref(), &[("tagging".to_owned(), "".to_owned())]);

            assert_eq!(qs.get_unique("taggin"), None);
            assert_eq!(qs.get_unique("tagging"), Some(""));
            assert_eq!(qs.get_unique("taggingg"), None);
        }

        {
            let query = "tagging&tagging";
            let qs = OrderedQs::parse(query).unwrap();
            assert_eq!(
                qs.as_ref(),
                &[("tagging".to_owned(), "".to_owned()), ("tagging".to_owned(), "".to_owned())]
            );

            assert_eq!(qs.get_unique("taggin"), None);
            assert_eq!(qs.get_unique("tagging"), None);
            assert_eq!(qs.get_unique("taggingg"), None);
        }
    }
}
+5 −5
Original line number Diff line number Diff line
@@ -90,7 +90,7 @@ fn extract_headers(req: &Request) -> S3Result<OrderedHeaders<'_>> {
}

fn extract_mime(hs: &OrderedHeaders<'_>) -> S3Result<Option<Mime>> {
    let Some(content_type) = hs.get(crate::header::CONTENT_TYPE) else { return Ok(None) };
    let Some(content_type) = hs.get_unique(crate::header::CONTENT_TYPE) else { return Ok(None) };
    match content_type.parse::<Mime>() {
        Ok(x) => Ok(Some(x)),
        Err(e) => Err(invalid_request!(e, "invalid content type")),
@@ -98,7 +98,7 @@ fn extract_mime(hs: &OrderedHeaders<'_>) -> S3Result<Option<Mime>> {
}

fn extract_amz_content_sha256<'a>(hs: &'_ OrderedHeaders<'a>) -> S3Result<Option<AmzContentSha256<'a>>> {
    let Some(val) = hs.get(crate::header::X_AMZ_CONTENT_SHA256) else { return Ok(None) };
    let Some(val) = hs.get_unique(crate::header::X_AMZ_CONTENT_SHA256) else { return Ok(None) };
    match AmzContentSha256::parse(val) {
        Ok(x) => Ok(Some(x)),
        Err(e) => {
@@ -111,7 +111,7 @@ fn extract_amz_content_sha256<'a>(hs: &'_ OrderedHeaders<'a>) -> S3Result<Option
}

fn extract_authorization_v4<'a>(hs: &'_ OrderedHeaders<'a>) -> S3Result<Option<AuthorizationV4<'a>>> {
    let Some(val) = hs.get(crate::header::AUTHORIZATION) else { return Ok(None) };
    let Some(val) = hs.get_unique(crate::header::AUTHORIZATION) else { return Ok(None) };
    match AuthorizationV4::parse(val) {
        Ok(x) => Ok(Some(x)),
        Err(e) => Err(invalid_request!(e, "invalid header: authorization")),
@@ -119,7 +119,7 @@ fn extract_authorization_v4<'a>(hs: &'_ OrderedHeaders<'a>) -> S3Result<Option<A
}

fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
    let Some(val) = hs.get(crate::header::X_AMZ_DATE) else { return Ok(None) };
    let Some(val) = hs.get_unique(crate::header::X_AMZ_DATE) else { return Ok(None) };
    match AmzDate::parse(val) {
        Ok(x) => Ok(Some(x)),
        Err(e) => Err(invalid_request!(e, "invalid header: x-amz-date")),
@@ -127,7 +127,7 @@ fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
}

fn extract_decoded_content_length(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<usize>> {
    let Some(val) = hs.get(crate::header::X_AMZ_DECODED_CONTENT_LENGTH) else { return Ok(None) };
    let Some(val) = hs.get_unique(crate::header::X_AMZ_DECODED_CONTENT_LENGTH) else { return Ok(None) };
    match atoi::atoi::<usize>(val.as_bytes()) {
        Some(x) => Ok(Some(x)),
        None => Err(invalid_request!("invalid header: x-amz-decoded-content-length")),
+4 −1
Original line number Diff line number Diff line
@@ -93,6 +93,9 @@ pub fn is_base64_encoded(bytes: &[u8]) -> bool {
/// `Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>`
pub type SyncBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;

pub fn stable_sort_by_first(v: &mut [(String, String)]) {
pub fn stable_sort_by_first<T>(v: &mut [(T, T)])
where
    T: Ord,
{
    v.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
}