Commit 3127a1c3 authored by Nugine's avatar Nugine
Browse files

feat(s3s/xml): add more serde

parent f1f3ac0c
Loading
Loading
Loading
Loading
+264 −292
Original line number Diff line number Diff line
@@ -22,8 +22,35 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
        ""
    ];

    codegen_xml_ser(ops, rust_types);
    codegen_xml_de(ops, rust_types);
    let (root_type_names, field_type_names) = collect_xml_types(ops, rust_types);

    for (&ty_name, &xml_name) in &root_type_names {
        match xml_name {
            Some(xml_name) if xml_name != ty_name => {
                g!("//   Serialize: {ty_name} {xml_name:?}");
                if can_impl_deserialize(rust_types, ty_name) {
                    g!("// Deserialize: {ty_name} {xml_name:?}");
                }
            }
            _ => {
                g!("//   Serialize: {ty_name}");
                if can_impl_deserialize(rust_types, ty_name) {
                    g!("// Deserialize: {ty_name}");
                }
            }
        }
    }
    g!();
    for ty_name in &field_type_names {
        g!("//   SerializeContent: {ty_name}");
        if can_impl_deserialize(rust_types, ty_name) {
            g!("// DeserializeContent: {ty_name}");
        }
    }
    g!();

    codegen_xml_serde(ops, rust_types, &root_type_names);
    codegen_xml_serde_content(ops, rust_types, &field_type_names);
}

pub fn is_xml_payload(field: &rust::StructField) -> bool {
@@ -35,29 +62,32 @@ pub fn is_xml_output(ty: &rust::Struct) -> bool {
    ty.xml_name.is_some() || ty.fields.iter().any(|field| field.position == "xml")
}

#[allow(clippy::too_many_lines)]
fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes) {
    let mut root_type_names: BTreeSet<&str> = default();
fn collect_xml_types<'a>(
    ops: &'a Operations,
    rust_types: &'a RustTypes,
) -> (BTreeMap<&'a str, Option<&'a str>>, BTreeSet<&'a str>) {
    let mut root_type_names: BTreeMap<&str, Option<&str>> = default();
    let mut field_type_names: BTreeSet<&str> = default();

    let mut q: VecDeque<&str> = default();

    for op in ops.values() {
        let ty_name = op.output.as_str();

        for ty_name in [op.input.as_str(), op.output.as_str()] {
            let rust_type = &rust_types[ty_name];
            let rust::Type::Struct(ty) = rust_type else { panic!() };

            if is_xml_output(ty) {
            root_type_names.insert(ty_name);
                root_type_names.insert(ty_name, None);
                field_type_names.insert(ty_name);
                q.push_back(ty_name);
            } else {
                assert!(ty.xml_name.is_none());
            }

            let mut payload_count = 0;
            for field in &ty.fields {
                if is_xml_payload(field) {
                root_type_names.insert(&field.type_);
                    root_type_names.insert(&field.type_, field.xml_name.as_deref());
                    field_type_names.insert(&field.type_);
                    q.push_back(&field.type_);
                    payload_count += 1;
@@ -65,13 +95,14 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes) {
            }
            assert!(payload_count <= 1);
        }
    }

    {
        let extra = ["Progress", "Stats"];
        for ty in extra {
            root_type_names.insert(ty);
            field_type_names.insert(ty);
            q.push_back(ty);
        for ty_name in extra {
            root_type_names.insert(ty_name, None);
            field_type_names.insert(ty_name);
            q.push_back(ty_name);
        }
    }

@@ -117,9 +148,137 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes) {
        }
    }

    for rust_type in field_type_names.iter().map(|&name| &rust_types[name]) {
    (root_type_names, field_type_names)
}

fn can_impl_deserialize(rust_types: &RustTypes, ty_name: &str) -> bool {
    let rust_type = &rust_types[ty_name];
    match rust_type {
        rust::Type::Struct(ty) => {
            for field in &ty.fields {
                if matches!(field.position.as_str(), "header" | "query" | "metadata") {
                    return false;
                }
            }
        }
        rust::Type::Alias(_) => {}
        rust::Type::List(_) => {}
        rust::Type::Map(_) => {}
        rust::Type::Provided(_) => {}
        rust::Type::StrEnum(_) => {}
        rust::Type::StructEnum(_) => {}
        rust::Type::Timestamp(_) => {}
    }
    true
}

fn s3_unwrapped_xml_output(ops: &Operations, ty_name: &str) -> bool {
    ops.iter().any(|(_, op)| op.s3_unwrapped_xml_output && op.output == ty_name)
}

fn codegen_xml_serde(ops: &Operations, rust_types: &RustTypes, root_type_names: &BTreeMap<&str, Option<&str>>) {
    for (rust_type, xml_name) in root_type_names.iter().map(|(&name, xml_name)| (&rust_types[name], xml_name)) {
        let rust::Type::Struct(ty) = rust_type else { panic!("{rust_type:#?}") };

        // https://github.com/Nugine/s3s/pull/127
        if s3_unwrapped_xml_output(ops, &ty.name) {
            assert_eq!(ty.name, "GetBucketLocationOutput");
            continue; // manually implemented
        }

        // https://github.com/Nugine/s3s/issues/2
        let xml_name = xml_name.or(ty.xml_name.as_deref()).unwrap_or(&ty.name);

        {
            g!("impl Serialize for {} {{", ty.name);
            g!("fn serialize<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

            g!("s.content(\"{xml_name}\", self)");

            g!("}}");
            g!("}}");

            g!();
        }

        if can_impl_deserialize(rust_types, &ty.name) {
            g!("impl<'xml> Deserialize<'xml> for {} {{", ty.name);
            g!("fn deserialize(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");

            g!("d.named_element(\"{xml_name}\", Deserializer::content)");

            g!("}}");
            g!("}}");
            g!();
        }
    }
}

fn codegen_xml_serde_content(ops: &Operations, rust_types: &RustTypes, field_type_names: &BTreeSet<&str>) {
    for rust_type in field_type_names.iter().map(|&name| &rust_types[name]) {
        match rust_type {
            rust::Type::Alias(_) => {}
            rust::Type::Provided(_) => {}
            rust::Type::Timestamp(_) => {}
            rust::Type::List(_) => panic!(),
            rust::Type::Map(_) => panic!(),
            rust::Type::StrEnum(ty) => {
                {
                    g!("impl SerializeContent for {} {{", ty.name);
                    g!("fn serialize_content<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

                    g!("self.as_str().serialize_content(s)");

                    g!("}}");
                    g!("}}");
                }
                {
                    g!("impl<'xml> DeserializeContent<'xml> for {} {{", ty.name);
                    g!("fn deserialize_content(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");
                    g!("String::deserialize_content(d).map(Self::from)");
                    g!("}}");
                    g!("}}");
                }
            }
            rust::Type::StructEnum(ty) => {
                {
                    g!("impl SerializeContent for {} {{", ty.name);
                    g!("fn serialize_content<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

                    g!("match self {{");

                    for variant in &ty.variants {
                        g!("Self::{0}(x) => s.content(\"{0}\", x),", variant.name);
                    }

                    g!("}}");

                    g!("}}");
                    g!("}}");
                }
                {
                    g!("impl<'xml> DeserializeContent<'xml> for {} {{", ty.name);
                    g!("fn deserialize_content(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");

                    g!("d.element(|d, x| match x {{");
                    for variant in &ty.variants {
                        g!("b\"{0}\" => Ok(Self::{0}(d.content()?)),", variant.name);
                    }
                    g!("_ => Err(DeError::UnexpectedTagName)");
                    g!("}})");

                    g!("}}");
                    g!("}}");
                }
            }
            rust::Type::Struct(ty) => codegen_xml_serde_content_struct(ops, rust_types, ty),
        }
    }
}

#[allow(clippy::too_many_lines)]
fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, ty: &rust::Struct) {
    {
        g!("impl SerializeContent for {} {{", ty.name);
        g!(
            "fn serialize_content<W: Write>(&self, {}: &mut Serializer<W>) -> SerResult {{",
@@ -154,15 +313,8 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes) {
                    g!("s.timestamp(\"{}\", &self.{}, TimestampFormat::{})?;", xml_name, field.name, fmt);
                }
            } else if field.option_type {
                        let s3_unwrapped_xml_output =
                            ops.iter().any(|(_, op)| op.s3_unwrapped_xml_output && op.output == ty.name);

                g!("if let Some(ref val) = self.{} {{", field.name);
                        if s3_unwrapped_xml_output {
                            g!("val.serialize_content(s)?;");
                        } else {
                g!("s.content(\"{xml_name}\", val)?;");
                        }
                g!("}}");
            } else {
                let default_is_zero = match field.default_value.as_ref() {
@@ -185,147 +337,15 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes) {

        g!("}}");
        g!("}}");
            }
            rust::Type::StrEnum(ty) => {
                g!("impl SerializeContent for {} {{", ty.name);
                g!("fn serialize_content<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

                g!("self.as_str().serialize_content(s)");

                g!("}}");
                g!("}}");
            }
            rust::Type::StructEnum(ty) => {
                g!("impl SerializeContent for {} {{", ty.name);
                g!("fn serialize_content<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

                g!("match self {{");

                for variant in &ty.variants {
                    g!("Self::{0}(x) => s.content(\"{0}\", x),", variant.name);
                }

                g!("}}");

                g!("}}");
                g!("}}");
            }
            rust::Type::Alias(_) => {}
            rust::Type::Provided(_) => {}
            rust::Type::Timestamp(_) => {}
            rust::Type::List(_) => panic!(),
            rust::Type::Map(_) => panic!(),
        }
        g!();
    }

    for rust_type in root_type_names.iter().map(|&name| &rust_types[name]) {
        let rust::Type::Struct(ty) = rust_type else { panic!("{rust_type:#?}") };

        g!("impl Serialize for {} {{", ty.name);
        g!("fn serialize<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

        let xml_name = ty.xml_name.as_deref().unwrap_or(ty.name.as_str());
        g!("s.content(\"{xml_name}\", self)");

        g!("}}");
        g!("}}");

        g!();
    }
}

#[allow(clippy::too_many_lines)]
fn codegen_xml_de(ops: &Operations, rust_types: &RustTypes) {
    let mut root_type_names: BTreeMap<&str, Option<&str>> = default();
    let mut field_type_names: BTreeSet<&str> = default();

    let mut q: VecDeque<&str> = default();

    for op in ops.values() {
        let ty_name = op.input.as_str();

        let rust_type = &rust_types[ty_name];
        let rust::Type::Struct(ty) = rust_type else { panic!() };
        assert!(ty.xml_name.is_none());

        let mut payload_count = 0;
        for field in &ty.fields {
            if is_xml_payload(field) {
                root_type_names.insert(&field.type_, field.xml_name.as_deref());
                field_type_names.insert(&field.type_);
                q.push_back(&field.type_);
                payload_count += 1;
            }
        }
        assert!(payload_count <= 1);
    }

    {
        let extra = ["Progress", "Stats"];
        for ty in extra {
            root_type_names.insert(ty, None);
            field_type_names.insert(ty);
            q.push_back(ty);
        }
    }

    while let Some(name) = q.pop_front() {
        let rust_type = &rust_types[name];
        match rust_type {
            rust::Type::Struct(ty) => {
                for field in &ty.fields {
                    let is_xml_field = field.position == "xml" || is_xml_payload(field);
                    if is_xml_field.not() {
                        continue;
                    }

                    let field_type = &rust_types[field.type_.as_str()];

                    if let rust::Type::List(list_ty) = field_type {
                        field_type_names.insert(list_ty.member.type_.as_str());
                        q.push_back(list_ty.member.type_.as_str());
                    } else {
                        field_type_names.insert(field.type_.as_str());
                        q.push_back(field.type_.as_str());
                    }
                }
            }
            rust::Type::Alias(_) => {}
            rust::Type::List(ty) => {
                field_type_names.insert(ty.member.type_.as_str());
                q.push_back(ty.member.type_.as_str());
            }
            rust::Type::StrEnum(ty) => {
                field_type_names.insert(ty.name.as_str());
            }
            rust::Type::StructEnum(ty) => {
                for variant in &ty.variants {
                    field_type_names.insert(variant.type_.as_str());
                    q.push_back(variant.type_.as_str());
                }
            }
            rust::Type::Provided(ty) => {
                assert!(matches!(ty.name.as_str(), "Event"));
            }
            rust::Type::Map(_) => unimplemented!(),
            rust::Type::Timestamp(_) => {}
        }
    }

    for rust_type in field_type_names.iter().map(|&name| &rust_types[name]) {
        match rust_type {
            rust::Type::Struct(ty) => {
    if can_impl_deserialize(rust_types, &ty.name) {
        g!("impl<'xml> DeserializeContent<'xml> for {} {{", ty.name);
        g!(
            "fn deserialize_content({}: &mut Deserializer<'xml>) -> DeResult<Self> {{",
            if ty.fields.is_empty() { "_" } else { "d" },
        );

                for field in &ty.fields {
                    assert!(field.position == "xml");
                }

        for field in &ty.fields {
            g!("let mut {}: Option<{}> = None;", field.name, field.type_);
        }
@@ -386,52 +406,4 @@ fn codegen_xml_de(ops: &Operations, rust_types: &RustTypes) {
        g!("}}");
        g!("}}");
    }
            rust::Type::StrEnum(ty) => {
                g!("impl<'xml> DeserializeContent<'xml> for {} {{", ty.name);
                g!("fn deserialize_content(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");
                g!("String::deserialize_content(d).map(Self::from)");
                g!("}}");
                g!("}}");
            }
            rust::Type::StructEnum(ty) => {
                g!("impl<'xml> DeserializeContent<'xml> for {} {{", ty.name);
                g!("fn deserialize_content(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");

                g!("d.element(|d, x| match x {{");
                for variant in &ty.variants {
                    g!("b\"{0}\" => Ok(Self::{0}(d.content()?)),", variant.name);
                }
                g!("_ => Err(DeError::UnexpectedTagName)");
                g!("}})");

                g!("}}");
                g!("}}");
            }

            rust::Type::Alias(_) => {}
            rust::Type::Provided(ty) => {
                assert!(matches!(ty.name.as_str(), "Event"));
            }
            rust::Type::List(_) => panic!(),
            rust::Type::Map(_) => panic!(),
            rust::Type::Timestamp(_) => {}
        }
        g!();
    }

    for (rust_type, xml_name) in root_type_names.iter().map(|(&name, xml_name)| (&rust_types[name], xml_name)) {
        let rust::Type::Struct(ty) = rust_type else { panic!("{rust_type:#?}") };

        g!("impl<'xml> Deserialize<'xml> for {} {{", ty.name);
        g!("fn deserialize(d: &mut Deserializer<'xml>) -> DeResult<Self> {{");

        assert!(ty.xml_name.is_none()); // canary for <https://github.com/Nugine/s3s/issues/2>

        let xml_name = xml_name.or(ty.xml_name.as_deref()).unwrap_or(&ty.name);
        g!("d.named_element(\"{xml_name}\", Deserializer::content)");

        g!("}}");
        g!("}}");
        g!();
    }
}
+5877 −3327

File changed.

Preview size limit exceeded, changes collapsed.

+38 −0
Original line number Diff line number Diff line
@@ -10,3 +10,41 @@ mod generated;

#[cfg(test)]
mod tests;

mod manually {
    use super::*;

    use crate::dto::BucketLocationConstraint;
    use crate::dto::GetBucketLocationOutput;

    impl Serialize for GetBucketLocationOutput {
        fn serialize<W: std::io::Write>(&self, s: &mut Serializer<W>) -> SerResult {
            if let Some(location_constraint) = &self.location_constraint {
                s.content("LocationConstraint", location_constraint)?;
            } else {
                s.content("LocationConstraint", "")?;
            }
            Ok(())
        }
    }

    impl<'xml> Deserialize<'xml> for GetBucketLocationOutput {
        fn deserialize(d: &mut Deserializer<'xml>) -> DeResult<Self> {
            let mut location_constraint: Option<BucketLocationConstraint> = None;
            d.for_each_element(|d, x| match x {
                b"LocationConstraint" => {
                    if location_constraint.is_some() {
                        return Err(DeError::DuplicateField);
                    }
                    let val: BucketLocationConstraint = d.content()?;
                    if !val.as_str().is_empty() {
                        location_constraint = Some(val);
                    }
                    Ok(())
                }
                _ => Err(DeError::UnexpectedTagName),
            })?;
            Ok(Self { location_constraint })
        }
    }
}
+48 −0
Original line number Diff line number Diff line
use crate::xml;

use std::fmt;

fn deserialize_content<T>(input: &[u8]) -> xml::DeResult<T>
where
    T: for<'xml> xml::DeserializeContent<'xml>,
{
    let mut d = xml::Deserializer::new(input);
    let ans = T::deserialize_content(&mut d)?;
    d.expect_eof()?;
    Ok(ans)
}

fn deserialize<T>(input: &[u8]) -> xml::DeResult<T>
where
    T: for<'xml> xml::Deserialize<'xml>,
@@ -28,6 +40,28 @@ fn serialize<T: xml::Serialize>(val: &T) -> xml::SerResult<String> {
    Ok(String::from_utf8(buf).unwrap())
}

fn test_serde<T>(val: &T)
where
    T: for<'xml> xml::Deserialize<'xml>,
    T: xml::Serialize,
    T: fmt::Debug + PartialEq,
{
    let xml = serialize(val).unwrap();
    let ans = deserialize::<T>(xml.as_bytes()).unwrap();
    assert_eq!(*val, ans);
}

fn test_serde_content<T>(val: &T)
where
    T: for<'xml> xml::DeserializeContent<'xml>,
    T: xml::SerializeContent,
    T: fmt::Debug + PartialEq,
{
    let xml = serialize_content(val).unwrap();
    let ans = deserialize_content::<T>(xml.as_bytes()).unwrap();
    assert_eq!(*val, ans);
}

/// See <https://github.com/Nugine/s3s/issues/2>
#[test]
fn d001() {
@@ -61,6 +95,8 @@ fn d001() {

    assert_eq!(parts[2].part_number, Some(3));
    assert_eq!(parts[2].e_tag.as_deref(), Some("\"acbd18db4cc2f85cedef654fccc4a4d8\""));

    test_serde(&ans);
}

#[test]
@@ -95,6 +131,8 @@ fn d002() {
        let csv = ans.input_serialization.csv.as_ref().unwrap();
        assert_eq!(csv.allow_quoted_record_delimiter, Some(false));
    }

    test_serde(&ans);
}

#[test]
@@ -116,6 +154,8 @@ fn d003() {
    let tag = &ans.tag_set[0];
    assert_eq!(tag.key, "Key4");
    assert_eq!(tag.value, "Value4");

    test_serde(&ans);
}

#[test]
@@ -140,6 +180,8 @@ fn d004() {

    assert!(ans.input_serialization.csv.is_some());
    assert!(ans.output_serialization.csv.is_some());

    test_serde(&ans);
}

#[test]
@@ -154,6 +196,8 @@ fn s001() {
    let expected = "<Days>365</Days>";

    assert_eq!(ans, expected);

    test_serde_content(&val);
}

#[test]
@@ -168,6 +212,8 @@ fn s002() {
        let expected = "<LocationConstraint>us-west-2</LocationConstraint>";

        assert_eq!(ans, expected);

        test_serde(&val);
    }
    {
        let val = crate::dto::GetBucketLocationOutput {
@@ -178,5 +224,7 @@ fn s002() {
        let expected = "<LocationConstraint></LocationConstraint>";

        assert_eq!(ans, expected);

        test_serde(&val);
    }
}