Commit 09fae30c authored by Nugine's avatar Nugine
Browse files

feat(s3s/ops): forward region & service to S3Request

parent 5c64b476
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@ pub(crate) struct S3Extensions {
    pub vec_stream: Option<VecByteStream>,

    pub credentials: Option<Credentials>,
    pub region: Option<String>,
    pub service: Option<String>,
}

impl From<hyper::Request<Body>> for Request {
+27 −9
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ mod get_object;
mod tests;

use crate::access::{S3Access, S3AccessContext};
use crate::auth::S3Auth;
use crate::auth::{Credentials, S3Auth};
use crate::error::*;
use crate::header;
use crate::host::S3Host;
@@ -55,18 +55,23 @@ pub struct CallContext<'a> {
}

fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
    let credentials = req.s3ext.credentials.take();
    let extensions = mem::take(&mut req.extensions);
    let headers = mem::take(&mut req.headers);
    let method = req.method.clone();
    let uri = mem::take(&mut req.uri);
    let headers = mem::take(&mut req.headers);
    let extensions = mem::take(&mut req.extensions);
    let credentials = req.s3ext.credentials.take();
    let region = req.s3ext.region.take();
    let service = req.s3ext.service.take();

    S3Request {
        method,
        uri,
        headers,
        input,
        credentials,
        extensions,
        headers,
        uri,
        method: req.method.clone(),
        credentials,
        region,
        service,
    }
}

@@ -297,7 +302,20 @@ async fn prepare(req: &mut Request, ccx: &CallContext<'_>) -> S3Result<Prepare>
            transformed_body = scx.transformed_body;

            req.s3ext.multipart = scx.multipart;
            req.s3ext.credentials = credentials;

            match credentials {
                Some(cred) => {
                    req.s3ext.credentials = Some(Credentials {
                        access_key: cred.access_key,
                        secret_key: cred.secret_key,
                    });
                    req.s3ext.region = cred.region;
                    req.s3ext.service = cred.service;
                }
                None => {
                    req.s3ext.credentials = None;
                }
            }
        }

        if body_changed {
+44 −18
Original line number Diff line number Diff line
use crate::auth::Credentials;
use crate::auth::S3Auth;
use crate::auth::SecretKey;
use crate::error::*;
use crate::http;
use crate::http::{AwsChunkedStream, Body, Multipart};
@@ -72,12 +72,19 @@ pub struct SignatureContext<'a> {
    pub multipart: Option<Multipart>,
}

pub struct CredentialsExt {
    pub access_key: String,
    pub secret_key: SecretKey,
    pub region: Option<String>,
    pub service: Option<String>,
}

fn require_auth(auth: Option<&dyn S3Auth>) -> S3Result<&dyn S3Auth> {
    auth.ok_or_else(|| s3_error!(NotImplemented, "This service has no authentication provider"))
}

impl SignatureContext<'_> {
    pub async fn check(&mut self) -> S3Result<Option<Credentials>> {
    pub async fn check(&mut self) -> S3Result<Option<CredentialsExt>> {
        if let Some(result) = self.v2_check().await {
            debug!("checked signature v2");
            return Ok(Some(result?));
@@ -92,7 +99,7 @@ impl SignatureContext<'_> {
    }

    #[tracing::instrument(skip(self))]
    pub async fn v4_check(&mut self) -> Option<S3Result<Credentials>> {
    pub async fn v4_check(&mut self) -> Option<S3Result<CredentialsExt>> {
        // POST auth
        if self.req_method == Method::POST {
            if let Some(ref mime) = self.mime {
@@ -120,7 +127,7 @@ impl SignatureContext<'_> {
        None
    }

    pub async fn v4_check_post_signature(&mut self) -> S3Result<Credentials> {
    pub async fn v4_check_post_signature(&mut self) -> S3Result<CredentialsExt> {
        let auth = require_auth(self.auth)?;

        let multipart = {
@@ -157,9 +164,11 @@ impl SignatureContext<'_> {
        let access_key = credential.access_key_id.to_owned();
        let secret_key = auth.get_secret_key(&access_key).await?;

        let region = credential.aws_region;
        let service = credential.aws_service;

        let string_to_sign = info.policy;
        let signature =
            sig_v4::calculate_signature(string_to_sign, &secret_key, &amz_date, credential.aws_region, credential.aws_service);
        let signature = sig_v4::calculate_signature(string_to_sign, &secret_key, &amz_date, region, service);

        let expected_signature = info.x_amz_signature;
        if signature != expected_signature {
@@ -167,11 +176,19 @@ impl SignatureContext<'_> {
            return Err(s3_error!(SignatureDoesNotMatch));
        }

        let region = region.to_owned();
        let service = service.to_owned();

        self.multipart = Some(multipart);
        Ok(Credentials { access_key, secret_key })
        Ok(CredentialsExt {
            access_key,
            secret_key,
            region: Some(region),
            service: Some(service),
        })
    }

    pub async fn v4_check_presigned_url(&mut self) -> S3Result<Credentials> {
    pub async fn v4_check_presigned_url(&mut self) -> S3Result<CredentialsExt> {
        let qs = self.qs.unwrap(); // assume: qs has "X-Amz-Signature"

        let presigned_url = PresignedUrlV4::parse(qs).map_err(|err| invalid_request!(err, "missing presigned url v4 fields"))?;
@@ -217,6 +234,9 @@ impl SignatureContext<'_> {
        let access_key = presigned_url.credential.access_key_id;
        let secret_key = auth.get_secret_key(access_key).await?;

        let region = presigned_url.credential.aws_region;
        let service = presigned_url.credential.aws_service;

        let signature = {
            let headers = self.hs.find_multiple(&presigned_url.signed_headers);
            let method = &self.req_method;
@@ -224,8 +244,6 @@ impl SignatureContext<'_> {

            let canonical_request = sig_v4::create_presigned_canonical_request(method, uri_path, qs.as_ref(), &headers);

            let region = presigned_url.credential.aws_region;
            let service = presigned_url.credential.aws_service;
            let amz_date = &presigned_url.amz_date;
            let string_to_sign = sig_v4::create_string_to_sign(&canonical_request, amz_date, region, service);

@@ -238,14 +256,16 @@ impl SignatureContext<'_> {
            return Err(s3_error!(SignatureDoesNotMatch));
        }

        Ok(Credentials {
        Ok(CredentialsExt {
            access_key: access_key.into(),
            secret_key,
            region: Some(region.into()),
            service: Some(service.into()),
        })
    }

    #[tracing::instrument(skip(self))]
    pub async fn v4_check_header_auth(&mut self) -> S3Result<Credentials> {
    pub async fn v4_check_header_auth(&mut self) -> S3Result<CredentialsExt> {
        let authorization: AuthorizationV4<'_> = {
            // assume: headers has "authorization"
            let mut a = extract_authorization_v4(&self.hs)?.unwrap();
@@ -344,14 +364,16 @@ impl SignatureContext<'_> {
            self.transformed_body = Some(Body::from(stream.into_byte_stream()));
        }

        Ok(Credentials {
        Ok(CredentialsExt {
            access_key: access_key.into(),
            secret_key,
            region: Some(region.into()),
            service: Some(service.into()),
        })
    }

    #[tracing::instrument(skip(self))]
    pub async fn v2_check(&mut self) -> Option<S3Result<Credentials>> {
    pub async fn v2_check(&mut self) -> Option<S3Result<CredentialsExt>> {
        if let Some(qs) = self.qs {
            if qs.has("Signature") {
                debug!("checking presigned url");
@@ -369,7 +391,7 @@ impl SignatureContext<'_> {
        None
    }

    pub async fn v2_check_header_auth(&mut self, auth_v2: AuthorizationV2<'_>) -> S3Result<Credentials> {
    pub async fn v2_check_header_auth(&mut self, auth_v2: AuthorizationV2<'_>) -> S3Result<CredentialsExt> {
        let method = &self.req_method;

        let date = self.hs.get_unique("date").or_else(|| self.hs.get_unique("x-amz-date"));
@@ -399,13 +421,15 @@ impl SignatureContext<'_> {
            return Err(s3_error!(SignatureDoesNotMatch));
        }

        Ok(Credentials {
        Ok(CredentialsExt {
            access_key: access_key.into(),
            secret_key,
            region: None,
            service: Some("s3".into()),
        })
    }

    pub async fn v2_check_presigned_url(&mut self) -> S3Result<Credentials> {
    pub async fn v2_check_presigned_url(&mut self) -> S3Result<CredentialsExt> {
        let qs = self.qs.unwrap(); // assume: qs has "Signature"
        let presigned_url = PresignedUrlV2::parse(qs).map_err(|err| invalid_request!(err, "missing presigned url v2 fields"))?;

@@ -433,9 +457,11 @@ impl SignatureContext<'_> {
            return Err(s3_error!(SignatureDoesNotMatch));
        }

        Ok(Credentials {
        Ok(CredentialsExt {
            access_key: access_key.into(),
            secret_key,
            region: None,
            service: Some("s3".into()),
        })
    }
}
+29 −19
Original line number Diff line number Diff line
@@ -7,49 +7,59 @@ use stdx::default::default;
#[derive(Debug)]
#[non_exhaustive]
pub struct S3Request<T> {
    /// HTTP method
    pub method: Method,

    // Raw URI
    pub uri: Uri,

    // Headers
    pub headers: HeaderMap<HeaderValue>,

    /// Operation input
    pub input: T,

    /// Identity information.
    ///
    /// `None` means anonymous request.
    pub credentials: Option<Credentials>,

    /// Request extensions
    ///
    /// It is used to pass custom data between middlewares.
    pub extensions: Extensions,

    // Headers
    pub headers: HeaderMap<HeaderValue>,
    /// Identity information.
    ///
    /// `None` means anonymous request.
    pub credentials: Option<Credentials>,

    // Raw URI
    pub uri: Uri,
    /// The requested region.
    pub region: Option<String>,

    /// HTTP method
    pub method: Method,
    /// The requested service.
    pub service: Option<String>,
}

impl<T> S3Request<T> {
    pub fn new(input: T) -> Self {
        Self {
            method: default(),
            uri: default(),
            headers: default(),
            input,
            credentials: default(),
            extensions: default(),
            headers: default(),
            uri: default(),
            method: default(),
            credentials: default(),
            region: default(),
            service: default(),
        }
    }

    pub fn map_input<U>(self, f: impl FnOnce(T) -> U) -> S3Request<U> {
        S3Request {
            method: self.method,
            uri: self.uri,
            headers: self.headers,
            input: f(self.input),
            credentials: self.credentials,
            extensions: self.extensions,
            headers: self.headers,
            uri: self.uri,
            method: self.method,
            credentials: self.credentials,
            region: self.region,
            service: self.service,
        }
    }
}