Unverified Commit a179d23b authored by rcshaw's avatar rcshaw Committed by GitHub
Browse files

feat(s3s): Add custom validation option to S3ServiceBuilder (#342)



* Add custom validation option to S3ServiceBuilder

* pr comments: fix error logging syntax, add more comments to trait, make DefaultNameValidation static

* remove unused import, update comment for syntax

* address PR comments: rename default validation, extract duplicate to var, remove static var

* reserve NameValidation constructor

* NameValidation requires 'static

---------

Co-authored-by: default avatarNugine <nugine@foxmail.com>
parent 0399dbd5
Loading
Loading
Loading
Loading
+121 −0
Original line number Diff line number Diff line
use s3s::auth::SimpleAuth;
use s3s::host::SingleDomain;
use s3s::service::S3ServiceBuilder;
use s3s::validation::NameValidation;
use s3s_fs::FileSystem;

use std::env;
@@ -84,6 +85,33 @@ async fn serial() -> MutexGuard<'static, ()> {
    LOCK.lock().await
}

fn create_client_with_validation(validation: impl NameValidation + 'static) -> Client {
    // Setup with custom validation
    let service = {
        let fs = FileSystem::new(FS_ROOT).unwrap();
        let mut b = S3ServiceBuilder::new(fs);
        let cred = Credentials::for_tests();
        b.set_auth(SimpleAuth::from_single(cred.access_key_id(), cred.secret_access_key()));
        b.set_host(SingleDomain::new(DOMAIN_NAME).unwrap());
        b.set_validation(validation);
        b.build()
    };

    // Convert to aws http client
    let client_inner = s3s_aws::Client::from(service);

    // Setup aws sdk config
    let cred = Credentials::for_tests();
    let config = SdkConfig::builder()
        .credentials_provider(SharedCredentialsProvider::new(cred))
        .http_client(client_inner)
        .region(Region::new(REGION))
        .endpoint_url(format!("http://{DOMAIN_NAME}"))
        .build();

    Client::new(&config)
}

async fn create_bucket(c: &Client, bucket: &str) -> Result<()> {
    let location = BucketLocationConstraint::from(REGION);
    let cfg = CreateBucketConfiguration::builder().location_constraint(location).build();
@@ -592,3 +620,96 @@ async fn test_single_object_get_range() -> Result<()> {

    Ok(())
}

#[tokio::test]
#[tracing::instrument]
async fn test_relaxed_bucket_validation() -> Result<()> {
    struct RelaxedNameValidation;

    impl NameValidation for RelaxedNameValidation {
        fn validate_bucket_name(&self, name: &str) -> bool {
            !name.is_empty()
        }
    }

    let _guard = serial().await;

    let c = create_client_with_validation(RelaxedNameValidation);

    // Test with bucket names that should pass with relaxed validation
    let relaxed_bucket_names = [
        "UPPERCASE-BUCKET",       // Uppercase not normally allowed
        "bucket_with_underscore", // Underscores not allowed
    ];

    for bucket_name in relaxed_bucket_names {
        let location = BucketLocationConstraint::from(REGION);
        let cfg = CreateBucketConfiguration::builder().location_constraint(location).build();

        let result = c
            .create_bucket()
            .create_bucket_configuration(cfg)
            .bucket(bucket_name)
            .send()
            .await;

        // Should not fail due to bucket name validation
        match result {
            Ok(_) => {
                debug!("Successfully created bucket with relaxed validation: {bucket_name}");

                // Verify the bucket was actually created by checking bucket existence
                let head_result = c.head_bucket().bucket(bucket_name).send().await;
                assert!(head_result.is_ok(), "Failed to head bucket {bucket_name} after creation");

                // Clean up the bucket
                let delete_result = delete_bucket(&c, bucket_name).await;
                assert!(delete_result.is_ok(), "Failed to delete bucket {bucket_name}");
            }
            Err(e) => {
                let error_str = format!("{e:?}");
                debug!("Bucket creation failed for other reasons (expected): {bucket_name} - {error_str}");
                // Verify it's not a bucket name validation error
                assert!(!error_str.contains("InvalidBucketName") && !error_str.contains("bucket name"));
            }
        }
    }

    Ok(())
}

#[tokio::test]
#[tracing::instrument]
async fn test_default_bucket_validation() -> Result<()> {
    let _guard = serial().await;

    let c = Client::new(config()); // Uses default validation

    // Test with invalid bucket names that should be rejected by AWS rules
    let invalid_bucket_names = [
        "UPPERCASE-BUCKET",       // Uppercase not allowed
        "bucket_with_underscore", // Underscores not allowed
        "bucket..double.dots",    // Consecutive dots not allowed
    ];

    for bucket_name in invalid_bucket_names {
        // Try to create bucket with invalid name - should fail with default validation
        let location = BucketLocationConstraint::from(REGION);
        let cfg = CreateBucketConfiguration::builder().location_constraint(location).build();

        let result = c
            .create_bucket()
            .create_bucket_configuration(cfg)
            .bucket(bucket_name)
            .send()
            .await;

        // Should fail due to bucket name validation
        assert!(result.is_err(), "Expected error for invalid bucket name: {bucket_name}");

        let error_str = format!("{:?}", result.unwrap_err());
        debug!("Default validation rejected bucket name {bucket_name}: {error_str}");
    }

    Ok(())
}
+1 −0
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ pub mod path;
pub mod route;
pub mod service;
pub mod stream;
pub mod validation;
pub mod xml;

pub use self::error::*;
+11 −2
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ use crate::route::S3Route;
use crate::s3_trait::S3;
use crate::stream::VecByteStream;
use crate::stream::aggregate_unlimited;
use crate::validation::{AwsNameValidation, NameValidation};

use std::mem;
use std::net::{IpAddr, SocketAddr};
@@ -60,6 +61,7 @@ pub struct CallContext<'a> {
    pub auth: Option<&'a dyn S3Auth>,
    pub access: Option<&'a dyn S3Access>,
    pub route: Option<&'a dyn S3Route>,
    pub validation: Option<&'a dyn NameValidation>,
}

fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
@@ -261,6 +263,9 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
        let vh;
        let vh_bucket;
        {
            let default_validation = &const { AwsNameValidation::new() };
            let validation = ccx.validation.unwrap_or(default_validation);

            let result = 'parse: {
                if let (Some(host_header), Some(s3_host)) = (host_header.as_deref(), ccx.host) {
                    if !is_socket_addr_or_ip_addr(host_header) {
@@ -270,13 +275,17 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
                        debug!(?vh);

                        vh_bucket = vh.bucket();
                        break 'parse crate::path::parse_virtual_hosted_style(vh_bucket, &decoded_uri_path);
                        break 'parse crate::path::parse_virtual_hosted_style_with_validation(
                            vh_bucket,
                            &decoded_uri_path,
                            validation,
                        );
                    }
                }

                debug!(?decoded_uri_path, "parsing path-style request");
                vh_bucket = None;
                crate::path::parse_path_style(&decoded_uri_path)
                crate::path::parse_path_style_with_validation(&decoded_uri_path, validation)
            };

            req.s3ext.s3_path = Some(result.map_err(|err| convert_parse_s3_path_error(&err))?);
+98 −3
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
//! + [Request styles](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAPI.html#virtual-hosted-path-style-requests)
//! + [Bucket naming rules](https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html)

use crate::validation::{AwsNameValidation, NameValidation};
use std::net::IpAddr;

/// A path in the S3 storage
@@ -159,6 +160,13 @@ pub const fn check_key(key: &str) -> bool {
/// # Errors
/// Returns an `Err` if the s3 path is invalid
pub fn parse_path_style(uri_path: &str) -> Result<S3Path, ParseS3PathError> {
    parse_path_style_with_validation(uri_path, &AwsNameValidation::new())
}

/// Parses a path-style request with custom validation
/// # Errors
/// Returns an `Err` if the s3 path is invalid
pub fn parse_path_style_with_validation(uri_path: &str, validation: &dyn NameValidation) -> Result<S3Path, ParseS3PathError> {
    let Some(path) = uri_path.strip_prefix('/') else { return Err(ParseS3PathError::InvalidPath) };

    if path.is_empty() {
@@ -171,7 +179,7 @@ pub fn parse_path_style(uri_path: &str) -> Result<S3Path, ParseS3PathError> {
        Some((bucket, key)) => (bucket, Some(key)),
    };

    if !check_bucket_name(bucket) {
    if !validation.validate_bucket_name(bucket) {
        return Err(ParseS3PathError::InvalidBucketName);
    }

@@ -188,11 +196,22 @@ pub fn parse_path_style(uri_path: &str) -> Result<S3Path, ParseS3PathError> {
/// # Errors
/// Returns an `Err` if the s3 path is invalid
pub fn parse_virtual_hosted_style(vh_bucket: Option<&str>, uri_path: &str) -> Result<S3Path, ParseS3PathError> {
    let Some(bucket) = vh_bucket else { return parse_path_style(uri_path) };
    parse_virtual_hosted_style_with_validation(vh_bucket, uri_path, &AwsNameValidation::new())
}

/// Parses a virtual-hosted-style request with custom validation
/// # Errors
/// Returns an `Err` if the s3 path is invalid
pub fn parse_virtual_hosted_style_with_validation(
    vh_bucket: Option<&str>,
    uri_path: &str,
    validation: &dyn NameValidation,
) -> Result<S3Path, ParseS3PathError> {
    let Some(bucket) = vh_bucket else { return parse_path_style_with_validation(uri_path, validation) };

    let Some(key) = uri_path.strip_prefix('/') else { return Err(ParseS3PathError::InvalidPath) };

    if !check_bucket_name(bucket) {
    if !validation.validate_bucket_name(bucket) {
        return Err(ParseS3PathError::InvalidBucketName);
    }

@@ -215,6 +234,8 @@ mod tests {
    use super::*;

    use crate::host::{S3Host, SingleDomain};
    use crate::validation::AwsNameValidation;
    use crate::validation::tests::RelaxedNameValidation;

    #[test]
    fn bucket_naming_rules() {
@@ -297,4 +318,78 @@ mod tests {
            assert_eq!(ans, expected);
        }
    }

    #[test]
    fn test_path_style_with_custom_validation() {
        // Test invalid bucket names that should pass with relaxed validation
        let invalid_names = [
            "UPPERCASE",              // uppercase not allowed in AWS
            "bucket_with_underscore", // underscores not allowed
            "bucket..double.dots",    // consecutive dots not allowed
            "bucket-",                // ending with hyphen not allowed
            "192.168.1.1",            // IP address format not allowed
        ];

        for bucket_name in invalid_names {
            let path = format!("/{bucket_name}/key");

            // Should fail with default validation
            let result = parse_path_style_with_validation(&path, &AwsNameValidation::new());
            assert!(result.is_err(), "Expected error for bucket name: {bucket_name}");

            // Should pass with relaxed validation
            let result = parse_path_style_with_validation(&path, &RelaxedNameValidation::new());
            assert!(result.is_ok(), "Expected success for bucket name: {bucket_name}");

            if let Ok(S3Path::Object { bucket, key }) = result {
                assert_eq!(bucket.as_ref(), bucket_name);
                assert_eq!(key.as_ref(), "key");
            }
        }

        // Test that valid names still work
        let result = parse_path_style_with_validation("/valid-bucket/key", &RelaxedNameValidation::new());
        assert!(result.is_ok());
    }

    #[test]
    fn test_virtual_hosted_style_with_custom_validation() {
        // Test invalid bucket names that should pass with relaxed validation
        let invalid_names = ["UPPERCASE", "bucket_with_underscore", "bucket..double.dots"];

        for bucket_name in invalid_names {
            // Should fail with default validation
            let result = parse_virtual_hosted_style_with_validation(Some(bucket_name), "/key", &AwsNameValidation::new());
            assert!(result.is_err(), "Expected error for bucket name: {bucket_name}");

            // Should pass with relaxed validation
            let result = parse_virtual_hosted_style_with_validation(Some(bucket_name), "/key", &RelaxedNameValidation::new());
            assert!(result.is_ok(), "Expected success for bucket name: {bucket_name}");

            if let Ok(S3Path::Object { bucket, key }) = result {
                assert_eq!(bucket.as_ref(), bucket_name);
                assert_eq!(key.as_ref(), "key");
            }
        }
    }

    #[test]
    fn test_path_style_validation_fallback() {
        // Test that parse_path_style uses AwsNameValidation
        let result1 = parse_path_style("/UPPERCASE/key");
        let result2 = parse_path_style_with_validation("/UPPERCASE/key", &AwsNameValidation::new());

        // Both should give the same result (error for invalid bucket name)
        assert_eq!(result1.is_err(), result2.is_err());
    }

    #[test]
    fn test_virtual_hosted_style_validation_fallback() {
        // Test that parse_virtual_hosted_style uses AwsNameValidation
        let result1 = parse_virtual_hosted_style(Some("UPPERCASE"), "/key");
        let result2 = parse_virtual_hosted_style_with_validation(Some("UPPERCASE"), "/key", &AwsNameValidation::new());

        // Both should give the same result (error for invalid bucket name)
        assert_eq!(result1.is_err(), result2.is_err());
    }
}
+45 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ use crate::host::S3Host;
use crate::http::{Body, Request};
use crate::route::S3Route;
use crate::s3_trait::S3;
use crate::validation::NameValidation;
use crate::{HttpError, HttpRequest, HttpResponse};

use std::fmt;
@@ -18,6 +19,7 @@ pub struct S3ServiceBuilder {
    auth: Option<Box<dyn S3Auth>>,
    access: Option<Box<dyn S3Access>>,
    route: Option<Box<dyn S3Route>>,
    validation: Option<Box<dyn NameValidation>>,
}

impl S3ServiceBuilder {
@@ -29,6 +31,7 @@ impl S3ServiceBuilder {
            auth: None,
            access: None,
            route: None,
            validation: None,
        }
    }

@@ -48,6 +51,10 @@ impl S3ServiceBuilder {
        self.route = Some(Box::new(route));
    }

    pub fn set_validation(&mut self, validation: impl NameValidation) {
        self.validation = Some(Box::new(validation));
    }

    #[must_use]
    pub fn build(self) -> S3Service {
        S3Service {
@@ -57,6 +64,7 @@ impl S3ServiceBuilder {
                auth: self.auth,
                access: self.access,
                route: self.route,
                validation: self.validation,
            }),
        }
    }
@@ -73,6 +81,7 @@ struct Inner {
    auth: Option<Box<dyn S3Auth>>,
    access: Option<Box<dyn S3Access>>,
    route: Option<Box<dyn S3Route>>,
    validation: Option<Box<dyn NameValidation>>,
}

impl S3Service {
@@ -95,6 +104,7 @@ impl S3Service {
            auth: self.inner.auth.as_deref(),
            access: self.inner.access.as_deref(),
            route: self.inner.route.as_deref(),
            validation: self.inner.validation.as_deref(),
        };
        let result = match crate::ops::call(&mut req, &ccx).await {
            Ok(resp) => Ok(HttpResponse::from(resp)),
@@ -203,4 +213,39 @@ mod tests {
        assert!(output_size(&S3Service::call) <= 2900);
        assert!(output_size(&S3Service::call_owned) <= 3200);
    }

    // Test validation functionality
    use crate::validation::NameValidation;

    // Mock S3 implementation for testing
    struct MockS3;
    impl S3 for MockS3 {}

    // Test validation that allows any bucket name
    struct RelaxedValidation;
    impl NameValidation for RelaxedValidation {
        fn validate_bucket_name(&self, _name: &str) -> bool {
            true // Allow any bucket name
        }
    }

    #[test]
    fn test_service_builder_validation() {
        let validation = RelaxedValidation;
        let mut builder = S3ServiceBuilder::new(MockS3);
        builder.set_validation(validation);
        let service = builder.build();

        // Verify validation was set
        assert!(service.inner.validation.is_some());
    }

    #[test]
    fn test_service_builder_default_validation() {
        let builder = S3ServiceBuilder::new(MockS3);
        let service = builder.build();

        // Should have default validation when none is set
        assert!(service.inner.validation.is_none()); // None means it will use AwsNameValidation
    }
}
Loading