Unverified Commit 5340484c authored by Nugine's avatar Nugine Committed by GitHub
Browse files

feat(s3s/host): add MultiDomain (#179)

* refactor: change domain arg

* feat(s3s/host): add MultiDomain

* fix
parent c94f3147
Loading
Loading
Loading
Loading
+7 −6
Original line number Diff line number Diff line
@@ -5,10 +5,11 @@ use s3s_fs::FileSystem;
use s3s_fs::Result;

use s3s::auth::SimpleAuth;
use s3s::host::SingleDomain;
use s3s::host::MultiDomain;
use s3s::service::S3ServiceBuilder;

use std::io::IsTerminal;
use std::ops::Not;
use std::path::PathBuf;

use tokio::net::TcpListener;
@@ -38,9 +39,9 @@ struct Opt {
    #[arg(long)]
    secret_key: Option<String>,

    /// Domain name used for virtual-hosted-style requests.
    /// Domain names used for virtual-hosted-style requests.
    #[arg(long)]
    domain_name: Option<String>,
    domain: Vec<String>,

    /// Root directory of stored data.
    root: PathBuf,
@@ -70,7 +71,7 @@ fn check_cli_args(opt: &Opt) {
        cmd.error(ErrorKind::MissingRequiredArgument, msg).exit();
    }

    if let Some(ref s) = opt.domain_name {
    for s in &opt.domain {
        if s.contains('/') {
            let msg = format!("expected domain name, found URL-like string: {s:?}");
            cmd.error(ErrorKind::InvalidValue, msg).exit();
@@ -103,8 +104,8 @@ async fn run(opt: Opt) -> Result {
        }

        // Enable parsing virtual-hosted-style requests
        if let Some(domain_name) = opt.domain_name {
            b.set_host(SingleDomain::new(domain_name));
        if opt.domain.is_empty().not() {
            b.set_host(MultiDomain::new(&opt.domain)?);
            info!("virtual-hosted-style requests are enabled");
        }

+1 −1
Original line number Diff line number Diff line
@@ -65,7 +65,7 @@ fn config() -> &'static SdkConfig {
        let service = {
            let mut b = S3ServiceBuilder::new(fs);
            b.set_auth(SimpleAuth::from_single(cred.access_key_id(), cred.secret_access_key()));
            b.set_host(SingleDomain::new(DOMAIN_NAME));
            b.set_host(SingleDomain::new(DOMAIN_NAME).unwrap());
            b.build()
        };

+3 −3
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ struct Opt {
    port: u16,

    #[clap(long)]
    domain_name: Option<String>,
    domain: Option<String>,

    #[clap(long)]
    endpoint_url: String,
@@ -66,8 +66,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
        }

        // Enable parsing virtual-hosted-style requests
        if let Some(domain_name) = opt.domain_name {
            b.set_host(SingleDomain::new(domain_name));
        if let Some(domain) = opt.domain {
            b.set_host(SingleDomain::new(&domain)?);
        }

        b.build()
+229 −11
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@ use crate::error::S3Result;

use std::borrow::Cow;

use rust_utils::default::default;

#[derive(Debug, Clone)]
pub struct VirtualHost<'a> {
    domain: Cow<'a, str>,
@@ -45,16 +47,79 @@ pub trait S3Host: Send + Sync + 'static {
    fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>>;
}

/// Errors reported while building a host matcher from base domains.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum DomainError {
    /// A supplied base domain did not pass the `is_valid_domain` check.
    #[error("The domain is invalid")]
    InvalidDomain,

    /// Two of the base domains supplied to `MultiDomain::new` overlap with each other.
    #[error("Some subdomains overlap with each other")]
    OverlappingSubdomains,

    /// `MultiDomain::new` was given no base domains at all.
    #[error("No base domains are specified")]
    ZeroDomains,
}

/// Naive syntactic check for a `host` or `host:port` string.
///
/// The input is accepted when:
/// + the host part is a non-empty, dot-separated list of non-empty labels,
///   each consisting only of ASCII alphanumerics and `-`;
/// + the optional port (everything after the first `:`) consists only of
///   ASCII digits and fits into a `u16`.
///
/// No DNS resolution or length checks are performed.
fn is_valid_domain(s: &str) -> bool {
    let host = match s.split_once(':') {
        Some((host, port)) => {
            // `str::parse::<u16>` also accepts a leading `+` sign, which is not
            // a valid port syntax, so insist on digits before parsing.
            let digits_only = !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit());
            if !digits_only || port.parse::<u16>().is_err() {
                return false;
            }
            host
        }
        None => s,
    };

    // Splitting an empty host yields a single empty label, which is rejected below.
    host.split('.').all(|label| {
        !label.is_empty() && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
    })
}

/// Match `host` against a single `base_domain`.
///
/// Returns a bucket-less virtual host when `host` is exactly the base domain,
/// a virtual host carrying the bucket when `host` is `<bucket>.<base_domain>`,
/// and `None` when `host` does not belong to the base domain.
fn parse_host_header<'a>(base_domain: &'a str, host: &'a str) -> Option<VirtualHost<'a>> {
    if host == base_domain {
        return Some(VirtualHost::new(base_domain));
    }

    // Whatever precedes `.<base_domain>` is taken as the bucket name.
    host.strip_suffix(base_domain)
        .and_then(|prefix| prefix.strip_suffix('.'))
        .map(|bucket| VirtualHost::with_bucket(base_domain, bucket))
}

/// Host matcher configured with a single base domain.
#[derive(Debug)]
pub struct SingleDomain {
    /// The base domain that incoming `Host` header values are compared against.
    base_domain: String,
}

impl SingleDomain {
    #[must_use]
    pub fn new(base_domain: impl Into<String>) -> Self {
        Self {
            base_domain: base_domain.into(),
    /// Create a new `SingleDomain` with the base domain.
    ///
    /// # Errors
    /// Returns an error if the base domain is invalid.
    pub fn new(base_domain: &str) -> Result<Self, DomainError> {
        if !is_valid_domain(base_domain) {
            return Err(DomainError::InvalidDomain);
        }

        Ok(Self {
            base_domain: base_domain.into(),
        })
    }
}

@@ -62,15 +127,168 @@ impl S3Host for SingleDomain {
    fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>> {
        let base_domain = self.base_domain.as_str();

        if host == base_domain {
            return Ok(VirtualHost::new(base_domain));
        if let Some(vh) = parse_host_header(base_domain, host) {
            return Ok(vh);
        }

        if let Some(bucket) = host.strip_suffix(&self.base_domain).and_then(|h| h.strip_suffix('.')) {
            return Ok(VirtualHost::with_bucket(base_domain, bucket));
        };
        if is_valid_domain(host) {
            let bucket = host.to_ascii_lowercase();
            return Ok(VirtualHost::with_bucket(host, bucket));
        }

        Err(s3_error!(InvalidRequest, "Invalid host header"))
    }
}

/// Host matcher configured with one or more base domains.
#[derive(Debug)]
pub struct MultiDomain {
    /// The base domains, kept in the order they were passed to `MultiDomain::new`.
    base_domains: Vec<String>,
}

impl MultiDomain {
    /// Create a new `MultiDomain` with the base domains.
    ///
    /// Two base domains overlap when they are equal or when one of them is a
    /// subdomain of the other (for example `example.com` and `s3.example.com`).
    /// Domains that merely share a textual suffix, such as `example.com` and
    /// `myexample.com`, do not overlap.
    ///
    /// # Errors
    /// Returns an error if
    /// + any of the base domains are invalid.
    /// + any of the base domains overlap with each other.
    /// + no base domains are specified.
    pub fn new<I>(base_domains: I) -> Result<Self, DomainError>
    where
        I: IntoIterator,
        I::Item: AsRef<str>,
    {
        /// Returns `true` when `sub` equals `base` or ends with `.<base>`.
        fn is_same_or_subdomain(sub: &str, base: &str) -> bool {
            match sub.strip_suffix(base) {
                // The match must start at a label boundary to be a real subdomain.
                Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
                None => false,
            }
        }

        let mut v: Vec<String> = default();

        for domain in base_domains {
            let domain = domain.as_ref();

            if !is_valid_domain(domain) {
                return Err(DomainError::InvalidDomain);
            }

            let overlaps = v
                .iter()
                .any(|other| is_same_or_subdomain(domain, other) || is_same_or_subdomain(other, domain));
            if overlaps {
                return Err(DomainError::OverlappingSubdomains);
            }

            v.push(domain.to_owned());
        }

        if v.is_empty() {
            return Err(DomainError::ZeroDomains);
        }

        Ok(Self { base_domains: v })
    }
}

impl S3Host for MultiDomain {
    /// Resolve a `Host` header value against the configured base domains.
    ///
    /// The base domains are tried in order and the first one that matches
    /// wins. A host that matches none of them but is a syntactically valid
    /// domain is treated as a bucket of its own, with the lower-cased host
    /// used as the bucket name.
    ///
    /// # Errors
    /// Returns an `InvalidRequest` error when the host matches no base domain
    /// and is not a valid domain.
    fn parse_host_header<'a>(&'a self, host: &'a str) -> S3Result<VirtualHost<'a>> {
        for base_domain in &self.base_domains {
            if let Some(vh) = parse_host_header(base_domain, host) {
                return Ok(vh);
            }
        }

        if is_valid_domain(host) {
            let bucket = host.to_ascii_lowercase();
            return Ok(VirtualHost::with_bucket(host, bucket));
        }

        Err(s3_error!(InvalidRequest, "Invalid host header"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::S3ErrorCode;

    /// `SingleDomain::new` keeps well-formed domains verbatim and rejects
    /// malformed ones with `DomainError::InvalidDomain`.
    #[test]
    fn single_domain_new() {
        for accepted in ["example.com", "example.com.org"] {
            let sd = SingleDomain::new(accepted).unwrap();
            assert_eq!(sd.base_domain, accepted);
        }

        // Trailing dot and empty port are both invalid.
        for rejected in ["example.com.", "example.com:"] {
            let outcome = SingleDomain::new(rejected);
            assert!(matches!(outcome, Err(DomainError::InvalidDomain)));
        }

        // An explicit port is allowed.
        assert!(SingleDomain::new("example.com:80").is_ok());
    }

    /// `MultiDomain::new` accepts disjoint domains, and rejects duplicates
    /// and empty input.
    #[test]
    fn multi_domain_new() {
        let disjoint = ["example.com", "example.org"];
        let md = MultiDomain::new(&disjoint).unwrap();
        assert_eq!(md.base_domains, disjoint);

        let duplicated = ["example.com", "example.com"];
        let outcome = MultiDomain::new(&duplicated);
        assert!(matches!(outcome, Err(DomainError::OverlappingSubdomains)));

        // Sharing a prefix is fine; only suffix relationships can overlap.
        let shared_prefix = ["example.com", "example.com.org"];
        let md = MultiDomain::new(&shared_prefix).unwrap();
        assert_eq!(md.base_domains, shared_prefix);

        let nothing: [&str; 0] = [];
        let outcome = MultiDomain::new(&nothing);
        assert!(matches!(outcome, Err(DomainError::ZeroDomains)));
    }

    /// `MultiDomain::parse_host_header` resolves base domains, foreign hosts,
    /// invalid hosts and bucket subdomains.
    #[test]
    fn multi_domain_parse() {
        let base_domains = ["example.com", "example.org"];
        let md = MultiDomain::new(base_domains.iter().copied()).unwrap();

        // A bare base domain carries no bucket.
        for host in ["example.com", "example.org"] {
            let vh = md.parse_host_header(host).unwrap();
            assert_eq!(vh.domain(), host);
            assert_eq!(vh.bucket(), None);
        }

        // A valid host outside every base domain becomes its own bucket.
        let foreign = "example.com.org";
        let vh = md.parse_host_header(foreign).unwrap();
        assert_eq!(vh.domain(), foreign);
        assert_eq!(vh.bucket(), Some("example.com.org"));

        // A trailing dot makes the host invalid.
        let err = md.parse_host_header("example.com.org.").unwrap_err();
        assert!(matches!(err.code(), S3ErrorCode::InvalidRequest));

        // Everything in front of the base domain is the bucket.
        let vh = md.parse_host_header("example.com.org.example.com").unwrap();
        assert_eq!(vh.domain(), "example.com");
        assert_eq!(vh.bucket(), Some("example.com.org"));
    }
}
+4 −4
Original line number Diff line number Diff line
@@ -257,7 +257,7 @@ mod tests {
    #[test]
    fn virtual_hosted_style() {
        {
            let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com");
            let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com").unwrap();
            let host = "s3.us-east-1.amazonaws.com";
            let uri_path = "/example.com/homepage.html";
            let vh = s3_host.parse_host_header(host).unwrap();
@@ -267,7 +267,7 @@ mod tests {
        }

        {
            let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com");
            let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com").unwrap();
            let host = "doc-example-bucket1.eu.s3.eu-west-1.amazonaws.com";
            let uri_path = "/homepage.html";
            let vh = s3_host.parse_host_header(host).unwrap();
@@ -277,7 +277,7 @@ mod tests {
        }

        {
            let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com");
            let s3_host = SingleDomain::new("s3.eu-west-1.amazonaws.com").unwrap();
            let host = "doc-example-bucket1.eu.s3.eu-west-1.amazonaws.com";
            let uri_path = "/";
            let vh = s3_host.parse_host_header(host).unwrap();
@@ -287,7 +287,7 @@ mod tests {
        }

        {
            let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com");
            let s3_host = SingleDomain::new("s3.us-east-1.amazonaws.com").unwrap();
            let host = "example.com";
            let uri_path = "/homepage.html";
            let vh = s3_host.parse_host_header(host).unwrap();
Loading