Commit 58fff185 authored by Nugine's avatar Nugine
Browse files

feat(s3s/route): support custom route

parent d7e8efc2
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ pub mod dto;
pub mod header;
pub mod host;
pub mod path;
pub mod route;
pub mod service;
pub mod stream;
pub mod xml;
+44 −10
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ use crate::http::{OrderedHeaders, OrderedQs};
use crate::http::{Request, Response};
use crate::path::{ParseS3PathError, S3Path};
use crate::request::S3Request;
use crate::route::S3Route;
use crate::s3_trait::S3;
use crate::stream::aggregate_unlimited;
use crate::stream::VecByteStream;
@@ -50,6 +51,7 @@ pub struct CallContext<'a> {
    pub host: Option<&'a dyn S3Host>,
    pub auth: Option<&'a dyn S3Auth>,
    pub access: Option<&'a dyn S3Access>,
    pub route: Option<&'a dyn S3Route>,
}

fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
@@ -177,7 +179,7 @@ fn fmt_content_length(len: usize) -> http::HeaderValue {
}

pub async fn call(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Response> {
    let op = match prepare(req, ccx).await {
    let prep = match prepare(req, ccx).await {
        Ok(op) => op,
        Err(err) => {
            debug!(?err, "failed to prepare");
@@ -185,19 +187,45 @@ pub async fn call(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Response
        }
    };

    let resp = match op.call(ccx, req).await {
        Ok(resp) => resp,
    match prep {
        Prepare::S3(op) => {
            match op.call(ccx, req).await {
                Ok(resp) => {
                    Ok(resp) //
                }
                Err(err) => {
                    debug!(op = %op.name(), ?err, "op returns error");
                    return serialize_error(err, false);
                }
    };
            }
        }
        Prepare::CustomRoute => {
            let body = mem::take(&mut req.body);
            let s3_req = build_s3_request(body, req);
            let route = ccx.route.unwrap();
            match route.call(s3_req).await {
                Ok(s3_resp) => Ok(Response {
                    status: s3_resp.output.0,
                    headers: s3_resp.headers,
                    body: s3_resp.output.1,
                    extensions: s3_resp.extensions,
                }),
                Err(err) => {
                    debug!(?err, "custom route returns error");
                    return serialize_error(err, false);
                }
            }
        }
    }
}

    Ok(resp)
enum Prepare {
    S3(&'static dyn Operation),
    CustomRoute,
}

#[allow(clippy::too_many_lines)]
async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<&'static dyn Operation> {
async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare> {
    let s3_path;
    let mut content_length;
    {
@@ -288,6 +316,12 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<&'static
        debug!(?body_changed, ?decoded_content_length, ?has_multipart);
    }

    if let Some(route) = ccx.route {
        if route.is_match(&req.method, &req.uri, &req.headers, &mut req.extensions) {
            return Ok(Prepare::CustomRoute);
        }
    }

    let (op, needs_full_body) = 'resolve: {
        if let Some(multipart) = &mut req.s3ext.multipart {
            if req.method == Method::POST {
@@ -343,5 +377,5 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<&'static
        extract_full_body(content_length, &mut req.body).await?;
    }

    Ok(op)
    Ok(Prepare::S3(op))
}
+45 −0
Original line number Diff line number Diff line
use crate::Body;
use crate::S3Request;
use crate::S3Response;
use crate::S3Result;

use hyper::http::Extensions;
use hyper::HeaderMap;
use hyper::Method;
use hyper::StatusCode;
use hyper::Uri;

#[async_trait::async_trait]
pub trait S3Route: Send + Sync + 'static {
    fn is_match(&self, method: &Method, uri: &Uri, headers: &HeaderMap, extensions: &mut Extensions) -> bool;

    async fn call(&self, req: S3Request<Body>) -> S3Result<S3Response<(StatusCode, Body)>>;
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::header;

    pub struct AssumeRole {}

    #[async_trait::async_trait]
    impl S3Route for AssumeRole {
        fn is_match(&self, method: &Method, uri: &Uri, headers: &HeaderMap, _: &mut Extensions) -> bool {
            if method == Method::POST && uri.path() == "/" {
                if let Some(val) = headers.get(header::CONTENT_TYPE) {
                    if val.as_bytes() == b"application/x-www-form-urlencoded" {
                        return true;
                    }
                }
            }
            false
        }

        async fn call(&self, _: S3Request<Body>) -> S3Result<S3Response<(StatusCode, Body)>> {
            tracing::debug!("call AssumeRole");
            return Err(s3_error!(NotImplemented));
        }
    }
}
+10 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ use crate::auth::S3Auth;
use crate::error::{S3Error, S3Result};
use crate::host::S3Host;
use crate::http::{Body, Request};
use crate::route::S3Route;
use crate::s3_trait::S3;

use std::convert::Infallible;
@@ -19,6 +20,7 @@ pub struct S3ServiceBuilder {
    host: Option<Box<dyn S3Host>>,
    auth: Option<Box<dyn S3Auth>>,
    access: Option<Box<dyn S3Access>>,
    route: Option<Box<dyn S3Route>>,
}

impl S3ServiceBuilder {
@@ -29,6 +31,7 @@ impl S3ServiceBuilder {
            host: None,
            auth: None,
            access: None,
            route: None,
        }
    }

@@ -44,6 +47,10 @@ impl S3ServiceBuilder {
        self.access = Some(Box::new(access));
    }

    pub fn set_route(&mut self, route: impl S3Route) {
        self.route = Some(Box::new(route));
    }

    #[must_use]
    pub fn build(self) -> S3Service {
        S3Service {
@@ -51,6 +58,7 @@ impl S3ServiceBuilder {
            host: self.host,
            auth: self.auth,
            access: self.access,
            route: self.route,
        }
    }
}
@@ -60,6 +68,7 @@ pub struct S3Service {
    host: Option<Box<dyn S3Host>>,
    auth: Option<Box<dyn S3Auth>>,
    access: Option<Box<dyn S3Access>>,
    route: Option<Box<dyn S3Route>>,
}

impl S3Service {
@@ -78,6 +87,7 @@ impl S3Service {
            host: self.host.as_deref(),
            auth: self.auth.as_deref(),
            access: self.access.as_deref(),
            route: self.route.as_deref(),
        };
        let result = crate::ops::call(&mut req, &ccx).await.map(Into::into);