Commit 6d8a5c3f authored by Nugine's avatar Nugine
Browse files

s3s: error: add headers field

parent a4aa6378
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ use std::fmt;
use std::io::Write;
use std::str::FromStr;

use hyper::HeaderMap;
use hyper::StatusCode;

pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -26,6 +27,7 @@ struct Inner {
    request_id: Option<String>,
    status_code: Option<StatusCode>,
    source: Option<StdError>,
    headers: Option<HeaderMap>,
}

impl S3Error {
@@ -38,6 +40,7 @@ impl S3Error {
            request_id: None,
            status_code: None,
            source: None,
            headers: None,
        }))
    }

@@ -75,6 +78,10 @@ impl S3Error {
        self.0.status_code = Some(val);
    }

    pub fn set_headers(&mut self, val: HeaderMap) {
        self.0.headers = Some(val);
    }

    #[must_use]
    pub fn code(&self) -> &S3ErrorCode {
        &self.0.code
@@ -100,6 +107,16 @@ impl S3Error {
        self.0.status_code.or_else(|| self.0.code.status_code())
    }

    #[must_use]
    pub fn headers(&self) -> Option<&HeaderMap> {
        self.0.headers.as_ref()
    }

    #[must_use]
    pub(crate) fn take_headers(&mut self) -> Option<HeaderMap> {
        self.0.headers.take()
    }

    #[must_use]
    pub fn internal_error<E>(source: E) -> Self
    where
+7 −4
Original line number Diff line number Diff line
@@ -58,11 +58,14 @@ fn build_s3_request<T>(input: T, req: &mut Request) -> S3Request<T> {
    }
}

fn serialize_error(x: S3Error) -> S3Result<Response> {
    let status = x.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
fn serialize_error(mut e: S3Error) -> S3Result<Response> {
    let status = e.status_code().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let mut res = Response::with_status(status);
    http::set_xml_body(&mut res, &x)?;
    drop(x);
    http::set_xml_body(&mut res, &e)?;
    if let Some(headers) = e.take_headers() {
        res.headers = headers;
    }
    drop(e);
    Ok(res)
}

+29 −0
Original line number Diff line number Diff line
@@ -33,3 +33,32 @@ fn track_future_size() {
        assert_eq!(size, expected, "{name:?} size changed: prev {expected}, now {size}");
    }
}

#[test]
fn error_custom_headers() {
    fn redirect307(location: &str) -> S3Error {
        let mut err = S3Error::new(S3ErrorCode::TemporaryRedirect);

        err.set_headers({
            let mut headers = HeaderMap::new();
            headers.insert(crate::header::LOCATION, location.parse().unwrap());
            headers
        });

        err
    }

    let res = serialize_error(redirect307("http://example.com")).unwrap();
    assert_eq!(res.status, StatusCode::TEMPORARY_REDIRECT);
    assert_eq!(res.headers.get("location").unwrap(), "http://example.com");

    let body = res.body.bytes().unwrap();
    let body = std::str::from_utf8(&body).unwrap();
    assert_eq!(
        body,
        concat!(
            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>",
            "<Error><Code>TemporaryRedirect</Code></Error>"
        )
    );
}