Commit 1746e263 authored by Nugine's avatar Nugine
Browse files

fix(s3s/signature): host header fallback on HTTP2

parent 6faf16ec
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -91,6 +91,29 @@ impl<'a> OrderedHeaders<'a> {
        }
        Self { headers }
    }

    /// Finds headers by names. Time `O(mlogn)`
    #[must_use]
    pub fn find_multiple_with_on_missing(
        &self,
        names: &'a [impl AsRef<str>],
        on_missing: impl Fn(&'a str) -> Option<&'a str>,
    ) -> Self {
        let mut headers: Vec<(&'a str, &'a str)> = Vec::new();
        for name in names {
            let mut has_value = false;
            for pair in self.get_all_pairs(name.as_ref()) {
                headers.push(pair);
                has_value = true;
            }
            if !has_value {
                if let Some(value) = on_missing(name.as_ref()) {
                    headers.push((name.as_ref(), value));
                }
            }
        }
        Self { headers }
    }
}

impl<'a> AsRef<[(&'a str, &'a str)]> for OrderedHeaders<'a> {
+2 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ use hyper::http::Extensions;
use hyper::http::HeaderValue;

pub struct Request {
    pub version: http::Version,
    pub method: Method,
    pub uri: Uri,
    pub headers: HeaderMap<HeaderValue>,
@@ -39,6 +40,7 @@ impl From<HttpRequest> for Request {
    fn from(req: HttpRequest) -> Self {
        let (parts, body) = req.into_parts();
        Self {
            version: parts.version,
            method: parts.method,
            uri: parts.uri,
            headers: parts.headers,
+2 −0
Original line number Diff line number Diff line
@@ -238,6 +238,7 @@ enum Prepare {
}

#[allow(clippy::too_many_lines)]
#[tracing::instrument(level = "debug", skip_all, err)]
async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare> {
    let s3_path;
    let mut content_length;
@@ -285,6 +286,7 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
            let mut scx = SignatureContext {
                auth: ccx.auth,

                req_version: req.version,
                req_method: &req.method,
                req_uri: &req.uri,
                req_body: &mut req.body,
+15 −1
Original line number Diff line number Diff line
@@ -54,6 +54,7 @@ fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
pub struct SignatureContext<'a> {
    pub auth: Option<&'a dyn S3Auth>,

    pub req_version: ::http::Version,
    pub req_method: &'a Method,
    pub req_uri: &'a Uri,
    pub req_body: &'a mut Body,
@@ -299,8 +300,21 @@ impl SignatureContext<'_> {
            let uri_path = &self.decoded_uri_path;
            let query_strings: &[(String, String)] = self.qs.as_ref().map_or(&[], AsRef::as_ref);

            // FIXME: throw error if any signed header is not in the request
            // `host` header need to be special handled

            // here requires that `auth.signed_headers` is sorted
            let headers = self.hs.find_multiple(&authorization.signed_headers);
            let headers = self.hs.find_multiple_with_on_missing(&authorization.signed_headers, |name| {
                // HTTP/2 replaces `host` header with `:authority`
                // but `:authority` is not in the request headers
                // so we need to add it back if `host` is in the signed headers
                if name == "host" && self.req_version == ::http::Version::HTTP_2 {
                    if let Some(authority) = self.req_uri.authority() {
                        return Some(authority.as_str());
                    }
                }
                None
            });

            let canonical_request = if is_stream {
                let payload = sig_v4::Payload::MultipleChunks;
+3 −3
Original line number Diff line number Diff line
@@ -192,8 +192,8 @@ mod tests {
        print_future_size!(S3Service::call_owned);

        // In case the futures are made too large accidentally
        assert!(output_size(&crate::ops::call) <= 1500);
        assert!(output_size(&S3Service::call) <= 2800);
        assert!(output_size(&S3Service::call_owned) <= 3100);
        assert!(output_size(&crate::ops::call) <= 1600);
        assert!(output_size(&S3Service::call) <= 2900);
        assert!(output_size(&S3Service::call_owned) <= 3200);
    }
}