Unverified Commit c15dbdec authored by Nugine's avatar Nugine
Browse files

s3s: virtual-hosted-style requests

parent ee6d0f0a
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -25,9 +25,11 @@ pub fn codegen(model: &smithy::Model, g: &mut Codegen) {
        headers.insert("x-amz-content-sha256");
        headers.insert("x-amz-date");
        headers.insert("authorization");
        headers.insert("host");
    }

    let prelude = [
        "//! Auto generated header name definitions",
        "#![allow(clippy::declare_interior_mutable_const)]",
        "",
        "use hyper::header::HeaderName;",
+11 −1
Original line number Diff line number Diff line
@@ -28,6 +28,9 @@ struct Opt {
    #[clap(long, requires("access-key"))]
    secret_key: Option<String>,

    #[clap(long)]
    domain_name: Option<String>,

    root: PathBuf,
}

@@ -43,18 +46,25 @@ fn setup_tracing() {
#[tokio::main]
async fn main() -> Result {
    setup_tracing();

    let opt = Opt::parse();

    // Setup S3 service
    let fs = FileSystem::new(opt.root)?;
    let mut service = S3Service::new(Box::new(fs));

    // Enable authentication
    if let (Some(ak), Some(sk)) = (opt.access_key, opt.secret_key) {
        let mut auth = SimpleAuth::new();
        auth.register(ak, sk);
        service.set_auth(Box::new(auth));
    }

    // Enable parsing virtual-hosted-style requests
    if let Some(domain_name) = opt.domain_name {
        service.set_base_domain(domain_name);
    }

    // Run server
    let listener = TcpListener::bind((opt.host.as_str(), opt.port))?;
    let server = Server::from_tcp(listener)?.serve(service.into_shared().into_make_service());

+3 −0
Original line number Diff line number Diff line
//! Auto generated header name definitions
#![allow(clippy::declare_interior_mutable_const)]

use hyper::header::HeaderName;
@@ -40,6 +41,8 @@ pub use hyper::header::ACCEPT_RANGES;

pub use hyper::header::AUTHORIZATION;

pub use hyper::header::HOST;

pub const X_AMZ_ABORT_DATE: HeaderName = HeaderName::from_static("x-amz-abort-date");

pub const X_AMZ_ABORT_RULE_ID: HeaderName = HeaderName::from_static("x-amz-abort-rule-id");
+19 −10
Original line number Diff line number Diff line
@@ -39,14 +39,23 @@ fn serialize_error(x: S3Error) -> S3Result<Response> {
    Ok(res)
}

fn extract_s3_path(req: &mut Request) -> S3Result<S3Path> {
    let path = urlencoding::decode(req.uri().path()).map_err(|_| S3ErrorCode::InvalidURI)?;
    let ans = crate::path::parse_path_style(&path).map_err(|err| match err {
        ParseS3PathError::InvalidPath => S3ErrorCode::InvalidURI,
        ParseS3PathError::InvalidBucketName => S3ErrorCode::InvalidBucketName,
        ParseS3PathError::KeyTooLong => S3ErrorCode::KeyTooLongError,
    })?;
    Ok(ans)
fn extract_s3_path(req: &mut Request, base_domain: Option<&str>) -> S3Result<S3Path> {
    let uri_path = urlencoding::decode(req.uri().path()).map_err(|_| S3ErrorCode::InvalidURI)?;

    let result = match (base_domain, req.headers().get(crate::header::names::HOST)) {
        (Some(base_domain), Some(val)) => {
            let on_err = |e| s3_error!(e, InvalidRequest, "invalid header: Host: {val:?}");
            let host = val.to_str().map_err(on_err)?;
            crate::path::parse_virtual_hosted_style(base_domain, host, &uri_path)
        }
        _ => crate::path::parse_path_style(&uri_path),
    };

    result.map_err(|err| match err {
        ParseS3PathError::InvalidPath => s3_error!(InvalidURI),
        ParseS3PathError::InvalidBucketName => s3_error!(InvalidBucketName),
        ParseS3PathError::KeyTooLong => s3_error!(KeyTooLongError),
    })
}

fn extract_qs(req: &mut Request) -> S3Result<Option<OrderedQs>> {
@@ -104,8 +113,8 @@ fn extract_amz_date(hs: &'_ OrderedHeaders<'_>) -> S3Result<Option<AmzDate>> {
    }
}

pub async fn call(s3: &dyn S3, auth: Option<&dyn S3Auth>, req: &mut Request) -> S3Result<Response> {
    let s3_path = extract_s3_path(req)?;
pub async fn call(req: &mut Request, s3: &dyn S3, auth: Option<&dyn S3Auth>, base_domain: Option<&str>) -> S3Result<Response> {
    let s3_path = extract_s3_path(req, base_domain)?;
    let qs = extract_qs(req)?;

    // check signature
+10 −1
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ pub struct S3Service {
    s3: Box<dyn S3>,
    auth: Option<Box<dyn S3Auth>>,
    full_body_limit: u64,
    base_domain: Option<String>,
}

impl S3Service {
@@ -24,6 +25,7 @@ impl S3Service {
            s3,
            auth: None,
            full_body_limit: crate::http::DEFAULT_LENGTH_LIMIT,
            base_domain: None,
        }
    }

@@ -35,6 +37,10 @@ impl S3Service {
        self.full_body_limit = length_limit;
    }

    pub fn set_base_domain(&mut self, base_domain: impl Into<String>) {
        self.base_domain = Some(base_domain.into());
    }

    #[tracing::instrument(
        level = "debug",
        skip(self, req),
@@ -47,7 +53,10 @@ impl S3Service {
            req.extensions_mut().insert(crate::http::LengthLimit(self.full_body_limit));
        }

        let result = crate::ops::call(&*self.s3, self.auth.as_deref(), &mut req).await;
        let s3 = &*self.s3;
        let auth = self.auth.as_deref();
        let base_domain = self.base_domain.as_deref();
        let result = crate::ops::call(&mut req, s3, auth, base_domain).await;

        match result {
            Ok(ref res) => debug!(?res),
Loading