Unverified Commit 12b2dfe9 authored by Nugine's avatar Nugine
Browse files

s3s: ops: small refactor

parent 37f76129
Loading
Loading
Loading
Loading
+88 −28
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@ use crate::stream::aggregate_unlimited;
use crate::stream::VecByteStream;
use crate::utils::is_base64_encoded;

use std::borrow::Cow;
use std::mem;
use std::ops::Not;

@@ -42,8 +41,8 @@ fn serialize_error(x: S3Error) -> S3Result<Response> {
    Ok(res)
}

fn extract_s3_path(req: &mut Request, base_domain: Option<&str>) -> S3Result<S3Path> {
    let uri_path = urlencoding::decode(req.uri().path()).map_err(|_| S3ErrorCode::InvalidURI)?;
fn extract_s3_path(req: &Request, decoded_uri_path: &str, base_domain: Option<&str>) -> S3Result<S3Path> {
    let uri_path = decoded_uri_path;

    let result = match (base_domain, req.headers().get(crate::header::names::HOST)) {
        (Some(base_domain), Some(val)) => {
@@ -51,11 +50,11 @@ fn extract_s3_path(req: &mut Request, base_domain: Option<&str>) -> S3Result<S3P
            let host = val.to_str().map_err(on_err)?;

            debug!(?base_domain, ?host, ?uri_path, "parsing virtual-hosted-style request");
            crate::path::parse_virtual_hosted_style(base_domain, host, &uri_path)
            crate::path::parse_virtual_hosted_style(base_domain, host, uri_path)
        }
        _ => {
            debug!(?uri_path, "parsing path-style request");
            crate::path::parse_path_style(&uri_path)
            crate::path::parse_path_style(uri_path)
        }
    };

@@ -159,18 +158,40 @@ async fn extract_full_body(req: &Request, body: &mut Body) -> S3Result<Bytes> {
    Ok(bytes)
}

#[allow(clippy::declare_interior_mutable_const)]
fn fmt_content_length(len: usize) -> http::HeaderValue {
    const ZERO: http::HeaderValue = http::HeaderValue::from_static("0");
    if len > 0 {
        crate::utils::fmt_usize(len, |s| http::HeaderValue::try_from(s).unwrap())
    } else {
        ZERO
    }
}

pub async fn call(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<Response> {
    match call_inner(req, s3, auth, base_domain).await {
    let op = match prepare(req, auth, base_domain).await {
        Ok(op) => op,
        Err(err) => {
            debug!(?err, "failed to prepare");
            return serialize_error(err);
        }
    };

    match op.call(s3, req).await {
        Ok(res) => Ok(res),
        Err(err) => {
            debug!(?err, "op returns error");
            debug!(op = %op.name(), ?err, "op returns error");
            serialize_error(err)
        }
    }
}

async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<Response> {
    let s3_path = extract_s3_path(req, base_domain)?;
async fn prepare(req: &mut Request, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<&'static dyn Operation> {
    let decoded_uri_path = urlencoding::decode(req.uri().path())
        .map_err(|_| S3ErrorCode::InvalidURI)?
        .into_owned();

    let s3_path = extract_s3_path(req, &decoded_uri_path, base_domain)?;
    let qs = extract_qs(req)?;

    // check signature
@@ -180,7 +201,6 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
        let headers = extract_headers(req)?;
        let mime = extract_mime(&headers)?;
        let decoded_content_length = extract_decoded_content_length(&headers)?;
        let decoded_uri_path = urlencoding::decode(req.uri().path()).map_err(|_| S3ErrorCode::InvalidURI)?;
        let body_transformed;
        {
            let mut scx = SignatureContext {
@@ -202,13 +222,9 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
            body_transformed = scx.body_transformed;
        }
        if body_transformed {
            // invalidate the original content length
            if let Some(val) = req.headers_mut().get_mut(header::names::CONTENT_LENGTH) {
                let len = decoded_content_length.unwrap_or(0);
                if len > 0 {
                    *val = crate::utils::fmt_usize(len, |s| http::HeaderValue::try_from(s).unwrap())
                } else {
                    *val = http::HeaderValue::from_static("0");
                }
                *val = fmt_content_length(decoded_content_length.unwrap_or(0))
            }
        }
        let has_multipart = multipart.is_some();
@@ -237,7 +253,7 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
        resolve_route(req, &s3_path, qs.as_ref())?
    };

    debug!(op = %op.name(), "resolved route");
    debug!(op = %op.name(), ?s3_path, "resolved route");

    req.extensions_mut().insert(s3_path);
    if let Some(qs) = qs {
@@ -249,12 +265,13 @@ async fn call_inner(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, b
    }
    *req.body_mut() = body;

    op.call(s3, req).await
    Ok(op)
}

struct SignatureContext<'a> {
    auth: Option<&'a dyn S3Auth>,
    req: &'a Request,
    decoded_uri_path: String,
    qs: Option<&'a OrderedQs>,
    headers: OrderedHeaders<'a>,
    mime: Option<Mime>,
@@ -262,7 +279,6 @@ struct SignatureContext<'a> {
    multipart: Option<Multipart>,
    body_transformed: bool,
    decoded_content_length: Option<usize>,
    decoded_uri_path: Cow<'a, str>,
}

impl SignatureContext<'_> {
@@ -293,6 +309,7 @@ impl SignatureContext<'_> {
    async fn check_post_signature(&mut self) -> S3Result<()> {
        let auth = require_auth(self.auth)?;

        let multipart = {
            let mime = self.mime.as_ref().unwrap(); // assume: multipart

            let boundary = mime
@@ -305,6 +322,9 @@ impl SignatureContext<'_> {
                .map_err(|e| s3_error!(e, MalformedPOSTRequest))?;
            self.body_transformed = true;

            multipart
        };

        let info = PostSignatureInfo::extract(&multipart).ok_or_else(|| invalid_request!("missing required multipart fields"))?;

        if is_base64_encoded(info.policy.as_bytes()).not() {
@@ -492,3 +512,43 @@ impl<'a> PostSignatureInfo<'a> {
fn require_auth(auth: Option<&dyn S3Auth>) -> S3Result<&dyn S3Auth> {
    auth.ok_or_else(|| s3_error!(NotImplemented, "This service has no authentication provider"))
}

#[cfg(test)]
mod tests {
    use super::*;

    pub trait OutputSize<A> {
        type Output;
    }

    macro_rules! impl_output_size {
        (($($ty:tt,)*)) => {
            impl<$($ty,)* F, R> OutputSize<($($ty,)*)> for F
            where
                F: Fn($($ty,)*) -> R ,
            {
                type Output = R;
            }
        };
    }

    impl_output_size!(());
    impl_output_size!((A0,));
    impl_output_size!((A0, A1,));
    impl_output_size!((A0, A1, A2,));
    impl_output_size!((A0, A1, A2, A3,));

    #[inline]
    #[must_use]
    pub const fn output_size<F, A>(_: &F) -> usize
    where
        F: OutputSize<A>,
    {
        std::mem::size_of::<F::Output>()
    }

    #[test]
    fn track_future_size() {
        assert_eq!(output_size(&call), 2928);
    }
}