Commit 14ec3e56 authored by Nugine's avatar Nugine
Browse files

s3s: ops: override resp headers

parent e221fbc9
Loading
Loading
Loading
Loading
+21 −6
Original line number Diff line number Diff line
@@ -604,16 +604,31 @@ fn codegen_op_http_call(op: &Operation) {
    let method = op.name.to_snake_case();

    g!("let input = Self::deserialize_http(req)?;");
    g!("let req = super::build_s3_request(input, req);");
    g!("let result = s3.{method}(req).await;");
    g!("let s3_req = super::build_s3_request(input, req);");

    if op.name == "GetObject" {
        g!("let overrided_headers = super::get_object::extract_overrided_response_headers(&s3_req)?;");
    }

    g!("let result = s3.{method}(s3_req).await;");

    glines![
        "match result {",
        "    Ok(resp) => super::serialize_s3_response(resp, Self::serialize_http),",
        "    Err(err) => super::serialize_error(err),",
        "}",
        "let s3_resp = match result {",
        "    Ok(val) => val,",
        "    Err(err) => return super::serialize_error(err),",
        "};",
    ];

    g!("let mut resp = Self::serialize_http(s3_resp.output)?;");

    if op.name == "GetObject" {
        g!("resp.headers.extend(overrided_headers);");
    }

    g!("resp.headers.extend(s3_resp.headers);");
    g!("resp.extensions.extend(s3_resp.extensions);");
    g!("Ok(resp)");

    g!("}}");
    g!("}}");
}
+932 −558

File changed.

Preview size limit exceeded, changes collapsed.

+44 −0
Original line number Diff line number Diff line
use crate::dto::GetObjectInput;
use crate::dto::Timestamp;
use crate::dto::TimestampFormat;
use crate::header;
use crate::utils;
use crate::S3Request;
use crate::S3Result;

use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::HeaderMap;

use rust_utils::default::default;

pub fn extract_overrided_response_headers(req: &S3Request<GetObjectInput>) -> S3Result<HeaderMap<HeaderValue>> {
    let mut map: HeaderMap<HeaderValue> = default();

    add(&mut map, header::CONTENT_TYPE, req.input.response_content_type.as_deref())?;
    add(&mut map, header::CONTENT_LANGUAGE, req.input.response_content_language.as_deref())?;
    add_ts(&mut map, header::EXPIRES, req.input.response_expires.as_ref())?;
    add(&mut map, header::CACHE_CONTROL, req.input.response_cache_control.as_deref())?;
    add(&mut map, header::CONTENT_DISPOSITION, req.input.response_content_disposition.as_deref())?;
    add(&mut map, header::CONTENT_ENCODING, req.input.response_content_encoding.as_deref())?;

    Ok(map)
}

fn add(map: &mut HeaderMap<HeaderValue>, name: HeaderName, value: Option<&str>) -> S3Result<()> {
    let error = |e| invalid_request!(e, "invalid overrided header: {name}: {value:?}");
    if let Some(value) = value {
        let value = value.parse().map_err(error)?;
        map.insert(name, value);
    }
    Ok(())
}

fn add_ts(map: &mut HeaderMap<HeaderValue>, name: HeaderName, value: Option<&Timestamp>) -> S3Result<()> {
    let error = |e| invalid_request!(e, "invalid overrided header: {name}: {value:?}");
    if let Some(value) = value {
        let value = utils::fmt_timestamp(value, TimestampFormat::HttpDate, HeaderValue::from_bytes).map_err(error)?;
        map.insert(name, value);
    }
    Ok(())
}
+2 −14
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ pub use self::generated::*;
mod signature;
use self::signature::SignatureContext;

mod get_object;

#[cfg(test)]
mod tests;

@@ -16,7 +18,6 @@ use crate::http::{OrderedHeaders, OrderedQs};
use crate::http::{Request, Response};
use crate::path::{ParseS3PathError, S3Path};
use crate::request::S3Request;
use crate::response::S3Response;
use crate::s3_trait::S3;
use crate::stream::aggregate_unlimited;
use crate::stream::VecByteStream;
@@ -56,19 +57,6 @@ fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
    }
}

fn serialize_s3_response<T>(resp: S3Response<T>, f: impl FnOnce(T) -> S3Result<http::Response>) -> S3Result<http::Response> {
    let S3Response {
        output,
        headers,
        extensions,
    } = resp;

    let mut ans = f(output)?;
    ans.headers.extend(headers);
    ans.extensions.extend(extensions);
    Ok(ans)
}

fn serialize_error(x: S3Error) -> S3Result<Response> {
    let status = x.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let mut res = Response::with_status(status);