Unverified Commit f31981cf authored by Nugine's avatar Nugine
Browse files

s3s: http: request, response

parent 24a10f95
Loading
Loading
Loading
Loading
+11 −16
Original line number Diff line number Diff line
@@ -119,6 +119,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {

fn status_code_name(code: u16) -> &'static str {
    match code {
        200 => "OK",
        204 => "NO_CONTENT",
        _ => unimplemented!(),
    }
@@ -220,7 +221,7 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
                g.ln("let mut res = http::Response::default();");

                let code_name = status_code_name(op.http_code);
                g.ln(f!("*res.status_mut() = http::StatusCode::{code_name};"));
                g.ln(f!("res.status = http::StatusCode::{code_name};"));

                g.ln("res");
            }
@@ -230,13 +231,9 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
        rust::Type::Struct(ty) => {
            if ty.fields.is_empty() {
                g.ln(f!("pub fn serialize_http(_: {output}) -> S3Result<http::Response> {{"));
                if op.http_code == 200 {
                    g.ln("Ok(http::Response::default())");
                } else {
                    g.ln("let mut res = http::Response::default();");
                {
                    let code_name = status_code_name(op.http_code);
                    g.ln(f!("*res.status_mut() = http::StatusCode::{code_name};"));
                    g.ln("Ok(res)");
                    g.ln(f!("Ok(http::Response::with_status(http::StatusCode::{code_name}))"));
                }
                g.ln("}");
            } else {
@@ -247,11 +244,9 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
                    assert!(["header", "metadata", "xml", "payload"].contains(&field.position.as_str()),);
                }

                g.ln("let mut res = http::Response::default();");

                if op.http_code != 200 {
                {
                    let code_name = status_code_name(op.http_code);
                    g.ln(f!("*res.status_mut() = http::StatusCode::{code_name};"));
                    g.ln(f!("let mut res = http::Response::with_status(http::StatusCode::{code_name});"));
                }

                if is_xml_output(ty) {
@@ -261,7 +256,7 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
                        "Policy" => {
                            assert!(field.option_type);
                            g.ln(f!("if let Some(val) = x.{} {{", field.name));
                            g.ln("*res.body_mut() = http::Body::from(val);");
                            g.ln("res.body = http::Body::from(val);");
                            g.ln("}");
                        }
                        "StreamingBlob" => {
@@ -342,7 +337,7 @@ fn codegen_op_http_de(op: &Operation, rust_types: &RustTypes, g: &mut Codegen) {

                if op.name == "PutObject" {
                    // POST object
                    g.ln("if let Some(m) = req.extensions_mut().remove::<http::Multipart>() {");
                    g.ln("if let Some(m) = req.extensions.remove::<http::Multipart>() {");
                    g.ln("    return Self::deserialize_http_multipart(req, m);");
                    g.ln("}");
                    g.lf();
@@ -512,7 +507,7 @@ fn codegen_op_http_de_multipart(op: &Operation, rust_types: &RustTypes, g: &mut
        "let bucket = http::unwrap_bucket(req);",
        "let key = http::parse_field_value(&m, \"key\")?.ok_or_else(|| invalid_request!(\"missing key\"))?;",
        "",
        "let vec_stream = req.extensions_mut().remove::<crate::stream::VecByteStream>().expect(\"missing vec stream\");",
        "let vec_stream = req.extensions.remove::<crate::stream::VecByteStream>().expect(\"missing vec stream\");",
        "",
        "let content_length = i64::try_from(vec_stream.exact_remaining_length()).map_err(|e|s3_error!(e, InvalidArgument, \"content-length overflow\"))?;",
        "",
@@ -779,7 +774,7 @@ fn codegen_router(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
        }
    };

    g.ln("match req.method().clone() {");
    g.ln("match req.method {");
    for &method in &methods {
        g.ln(f!("hyper::Method::{method} => match s3_path {{"));

@@ -901,7 +896,7 @@ fn codegen_router(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                                if cond.is_empty().not() {
                                    cond.push_str(" && ");
                                }
                                cond.push_str(&f!("req.headers().contains_key(\"{h}\")"));
                                cond.push_str(&f!("req.headers.contains_key(\"{h}\")"));
                            }

                            if qs.is_empty().not() {
+14 −14
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ where
    T: TryFromHeaderValue,
    T::Error: std::error::Error + Send + Sync + 'static,
{
    let mut iter = req.headers().get_all(name).into_iter();
    let mut iter = req.headers.get_all(name).into_iter();
    let Some(val) = iter.next() else { return Err(missing_header(name)) };
    let None = iter.next() else { return Err(duplicate_header(name)) } ;

@@ -44,7 +44,7 @@ where
    T: TryFromHeaderValue,
    T::Error: std::error::Error + Send + Sync + 'static,
{
    let mut iter = req.headers().get_all(name).into_iter();
    let mut iter = req.headers.get_all(name).into_iter();
    let Some(val) = iter.next() else { return Ok(None) };
    let None = iter.next() else { return Err(duplicate_header(name)) } ;

@@ -55,7 +55,7 @@ where
}

pub fn parse_opt_header_timestamp(req: &Request, name: &HeaderName, fmt: TimestampFormat) -> S3Result<Option<Timestamp>> {
    let mut iter = req.headers().get_all(name).into_iter();
    let mut iter = req.headers.get_all(name).into_iter();
    let Some(val) = iter.next() else { return Ok(None) };
    let None = iter.next() else { return Err(duplicate_header(name)) } ;

@@ -72,7 +72,7 @@ where
    T::Error: std::error::Error + Send + Sync + 'static,
{
    let mut list = List::new();
    for val in req.headers().get_all(name) {
    for val in req.headers.get_all(name) {
        let ans = T::try_from_header_value(val).map_err(|err| invalid_header(err, name, val))?;
        list.push(ans)
    }
@@ -98,7 +98,7 @@ pub fn parse_query<T: FromStr>(req: &Request, name: &str) -> S3Result<T>
where
    T::Err: std::error::Error + Send + Sync + 'static,
{
    let Some(qs) = req.extensions().get::<OrderedQs>() else { return Err(missing_query(name)) };
    let Some(qs) = req.extensions.get::<OrderedQs>() else { return Err(missing_query(name)) };

    let mut iter = qs.get_all(name);
    let Some(val) = iter.next() else { return Err(missing_query(name)) };
@@ -111,7 +111,7 @@ pub fn parse_opt_query<T: FromStr>(req: &Request, name: &str) -> S3Result<Option
where
    T::Err: std::error::Error + Send + Sync + 'static,
{
    let Some(qs) = req.extensions().get::<OrderedQs>() else { return Ok(None) };
    let Some(qs) = req.extensions.get::<OrderedQs>() else { return Ok(None) };

    let mut iter = qs.get_all(name);
    let Some(val) = iter.next() else { return Ok(None) };
@@ -121,7 +121,7 @@ where
}

pub fn parse_opt_query_timestamp(req: &Request, name: &str, fmt: TimestampFormat) -> S3Result<Option<Timestamp>> {
    let Some(qs) = req.extensions().get::<OrderedQs>() else { return Ok(None) };
    let Some(qs) = req.extensions.get::<OrderedQs>() else { return Ok(None) };

    let mut iter = qs.get_all(name);
    let Some(val) = iter.next() else { return Ok(None) };
@@ -132,13 +132,13 @@ pub fn parse_opt_query_timestamp(req: &Request, name: &str, fmt: TimestampFormat

#[track_caller]
pub fn unwrap_bucket(req: &mut Request) -> String {
    let Some(S3Path::Bucket { bucket }) = req.extensions_mut().remove::<S3Path>() else { panic!("url parameter not found") };
    let Some(S3Path::Bucket { bucket }) = req.extensions.remove::<S3Path>() else { panic!("url parameter not found") };
    bucket.into()
}

#[track_caller]
pub fn unwrap_object(req: &mut Request) -> (String, String) {
    let Some(S3Path::Object { bucket, key }) = req.extensions_mut().remove::<S3Path>() else { panic!("url parameter not found") };
    let Some(S3Path::Object { bucket, key }) = req.extensions.remove::<S3Path>() else { panic!("url parameter not found") };
    (bucket.into(), key.into())
}

@@ -160,7 +160,7 @@ pub fn take_xml_body<T>(req: &mut Request) -> S3Result<T>
where
    T: for<'xml> xml::Deserialize<'xml>,
{
    let bytes = req.body_mut().bytes().expect("full body not found");
    let bytes = req.body.bytes().expect("full body not found");
    if bytes.is_empty() {
        return Err(S3ErrorCode::MissingRequestBodyError.into());
    }
@@ -175,7 +175,7 @@ pub fn take_opt_xml_body<T>(req: &mut Request) -> S3Result<Option<T>>
where
    T: for<'xml> xml::Deserialize<'xml>,
{
    let bytes = req.body_mut().bytes().expect("full body not found");
    let bytes = req.body.bytes().expect("full body not found");
    if bytes.is_empty() {
        return Ok(None);
    }
@@ -187,7 +187,7 @@ where
}

pub fn take_string_body(req: &mut Request) -> S3Result<String> {
    let bytes = req.body_mut().bytes().expect("full body not found");
    let bytes = req.body.bytes().expect("full body not found");
    match from_utf8_vec(bytes.into()) {
        Some(s) => Ok(s),
        None => Err(invalid_request!("expected UTF-8 body")),
@@ -195,7 +195,7 @@ pub fn take_string_body(req: &mut Request) -> S3Result<String> {
}

pub fn take_stream_body(req: &mut Request) -> StreamingBlob {
    let body = std::mem::take(req.body_mut());
    let body = std::mem::take(&mut req.body);
    let size_hint = http_body::Body::size_hint(&body);
    debug!(?size_hint, "taking streaming blob");
    StreamingBlob::from(body)
@@ -203,7 +203,7 @@ pub fn take_stream_body(req: &mut Request) -> StreamingBlob {

pub fn parse_opt_metadata(req: &Request) -> S3Result<Option<Metadata>> {
    let mut metadata = Metadata::default();
    let map = req.headers();
    let map = &req.headers;
    for name in map.keys() {
        let Some(key) = name.as_str().strip_prefix("x-amz-meta-") else { continue };
        if key.is_empty() {
+5 −2
Original line number Diff line number Diff line
@@ -19,8 +19,11 @@ pub use self::multipart::*;
mod body;
pub use self::body::*;

pub type Request = hyper::Request<Body>;
pub type Response = hyper::Response<Body>;
mod request;
pub use self::request::Request;

mod response;
pub use self::response::Response;

pub use hyper::header::{HeaderName, HeaderValue, InvalidHeaderValue};
pub use hyper::http::StatusCode;
+28 −0
Original line number Diff line number Diff line
use super::Body;

use hyper::http::Extensions;
use hyper::http::HeaderValue;
use hyper::HeaderMap;
use hyper::Method;
use hyper::Uri;

pub struct Request {
    pub method: Method,
    pub uri: Uri,
    pub headers: HeaderMap<HeaderValue>,
    pub extensions: Extensions,
    pub body: Body,
}

impl From<hyper::Request<Body>> for Request {
    fn from(req: hyper::Request<Body>) -> Self {
        let (parts, body) = req.into_parts();
        Self {
            method: parts.method,
            uri: parts.uri,
            headers: parts.headers,
            extensions: parts.extensions,
            body,
        }
    }
}
+32 −0
Original line number Diff line number Diff line
use super::Body;

use hyper::http::HeaderValue;
use hyper::HeaderMap;
use hyper::StatusCode;

#[derive(Default)]
pub struct Response {
    pub status: StatusCode,
    pub headers: HeaderMap<HeaderValue>,
    pub body: Body,
}

impl From<Response> for hyper::Response<Body> {
    fn from(res: Response) -> Self {
        let mut ans = hyper::Response::default();
        *ans.status_mut() = res.status;
        *ans.headers_mut() = res.headers;
        *ans.body_mut() = res.body;
        ans
    }
}

impl Response {
    #[must_use]
    pub fn with_status(status: StatusCode) -> Self {
        Self {
            status,
            ..Default::default()
        }
    }
}
Loading