Commit 140e2747 (unverified), authored by 0xdx2, committed by GitHub
Browse files

feat: add xml_attr field and related functionality for XML serialization (#299)



* feat: add xml_attr field and related functionality for XML serialization

* feat: add xml_namespace support for XML serialization and deserialization

* fix: simplify attribute handling and enhance documentation for peek_event and consume_peeked methods

* revert unnecessary change

* flatten xml_namespace

* enhance test

* fix: improve XML deserialization by adding namespace support and refactoring attribute handling

* refactor: change DeEvent visibility to private and update peek_event method to private

* revert unnecessary changes

* fix xml deserializer

* revert unnecessary change

* fix

* fix

---------

Co-authored-by: damon <damonxue2@gmail.com>
Co-authored-by: Nugine <nugine@foxmail.com>
parent b86c15cf
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -208,6 +208,11 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
                        http_query: field.traits.http_query().map(o),
                        xml_name: field.traits.xml_name().map(o),
                        xml_flattened: field.traits.xml_flattened(),

                        is_xml_attr: field.traits.xml_attr(),
                        xml_namespace_uri: field.traits.xml_namespace_uri().map(o),
                        xml_namespace_prefix: field.traits.xml_namespace_prefix().map(o),

                        is_custom_extension: field.traits.minio(),
                    };
                    fields.push(field);
@@ -317,6 +322,9 @@ fn patch_types(space: &mut RustTypes) {
            http_query: None,
            xml_name: Some(request.name.clone()),
            xml_flattened: false,
            is_xml_attr: false,
            xml_namespace_uri: None,
            xml_namespace_prefix: None,
            is_custom_extension: false,
        });
        ty.name = o("SelectObjectContentInput");
+4 −0
Original line number Diff line number Diff line
@@ -98,6 +98,10 @@ pub struct StructField {
    pub xml_name: Option<String>,
    pub xml_flattened: bool,

    pub is_xml_attr: bool,
    pub xml_namespace_uri: Option<String>,
    pub xml_namespace_prefix: Option<String>,

    pub is_custom_extension: bool,
}

+106 −13
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ use super::rust::default_value_literal;

use crate::declare_codegen;
use crate::v1::ops::is_op_output;
use crate::v1::rust::StructField;

use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::ops::Not;
@@ -31,7 +32,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
    for (&ty_name, &xml_name) in &root_type_names {
        match xml_name {
            Some(xml_name) if xml_name != ty_name => {
                if can_impl_serialize(ty_name) {
                if can_impl_serialize(rust_types, ty_name) {
                    g!("//   Serialize: {ty_name} {xml_name:?}");
                }
                if can_impl_deserialize(rust_types, ty_name) {
@@ -39,7 +40,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
                }
            }
            _ => {
                if can_impl_serialize(ty_name) {
                if can_impl_serialize(rust_types, ty_name) {
                    g!("//   Serialize: {ty_name}");
                }
                if can_impl_deserialize(rust_types, ty_name) {
@@ -50,7 +51,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
    }
    g!();
    for ty_name in &field_type_names {
        if can_impl_serialize_content(ty_name) {
        if can_impl_serialize_content(rust_types, ty_name) {
            g!("//   SerializeContent: {ty_name}");
        }
        if can_impl_deserialize_content(rust_types, ty_name) {
@@ -166,14 +167,14 @@ fn collect_xml_types<'a>(

const SPECIAL_TYPES: &[&str] = &["AssumeRoleOutput"];

fn can_impl_serialize(ty_name: &str) -> bool {
fn can_impl_serialize(rust_types: &RustTypes, ty_name: &str) -> bool {
    if SPECIAL_TYPES.contains(&ty_name) {
        return false;
    }
    can_impl_serialize_content(ty_name)
    can_impl_serialize_content(rust_types, ty_name)
}

fn can_impl_serialize_content(_ty_name: &str) -> bool {
fn can_impl_serialize_content(_rust_types: &RustTypes, _ty_name: &str) -> bool {
    true
}

@@ -193,6 +194,9 @@ fn can_impl_deserialize_content(rust_types: &RustTypes, ty_name: &str) -> bool {
                if matches!(field.position.as_str(), "header" | "query" | "metadata") {
                    return false;
                }
                if field.is_xml_attr {
                    return false;
                }
            }
        }
        rust::Type::Alias(_) => {}
@@ -223,7 +227,7 @@ fn codegen_xml_serde(ops: &Operations, rust_types: &RustTypes, root_type_names:
        // https://github.com/Nugine/s3s/issues/2
        let xml_name = xml_name.or(ty.xml_name.as_deref()).unwrap_or(&ty.name);

        if can_impl_serialize(&ty.name) {
        if can_impl_serialize(rust_types, &ty.name) {
            g!("impl Serialize for {} {{", ty.name);
            g!("fn serialize<W: Write>(&self, s: &mut Serializer<W>) -> SerResult {{");

@@ -316,7 +320,7 @@ fn codegen_xml_serde_content(ops: &Operations, rust_types: &RustTypes, field_typ

#[allow(clippy::too_many_lines)]
fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, ty: &rust::Struct) {
    if can_impl_serialize_content(&ty.name) {
    if can_impl_serialize_content(rust_types, &ty.name) {
        g!("impl SerializeContent for {} {{", ty.name);
        g!(
            "fn serialize_content<W: Write>(&self, {}: &mut Serializer<W>) -> SerResult {{",
@@ -351,9 +355,21 @@ fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, t
                    g!("s.timestamp(\"{}\", &self.{}, TimestampFormat::{})?;", xml_name, field.name, fmt);
                }
            } else if field.option_type {
                // Check if field has xml_namespace trait (needs attributes)
                if let (Some(uri), Some(prefix)) = (&field.xml_namespace_uri, &field.xml_namespace_prefix) {
                    assert_eq!(prefix, "xsi");
                    g!("if let Some(ref val) = self.{} {{", field.name);
                    g!("let attrs = [");
                    g!("(\"xmlns:{}\", \"{}\"),", prefix, uri);
                    g!("(\"{}:type\", val.type_.as_str()),", prefix);
                    g!("];");
                    g!("s.content_with_attrs(\"{}\", &attrs, val)?;", xml_name);
                    g!("}}");
                } else {
                    g!("if let Some(ref val) = self.{} {{", field.name);
                    g!("s.content(\"{xml_name}\", val)?;");
                    g!("}}");
                }
            } else {
                let default_is_zero = match field.default_value.as_ref() {
                    Some(v) => v.as_u64() == Some(0),
@@ -366,6 +382,9 @@ fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, t
                    g!("s.content(\"{}\", &self.{})?;", xml_name, field.name);
                    g!("}}");
                } else {
                    if field.is_xml_attr {
                        continue; // skip xml attribute fields
                    }
                    g!("s.content(\"{}\", &self.{})?;", xml_name, field.name);
                }
            }
@@ -384,6 +403,14 @@ fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, t
            if ty.fields.is_empty() { "_" } else { "d" },
        );

        let mut xml_ns_field: Option<&StructField> = None;
        for field in &ty.fields {
            if field.xml_namespace_prefix.is_some() {
                assert!(xml_ns_field.is_none());
                xml_ns_field = Some(field);
            }
        }

        for field in &ty.fields {
            if field.position == "sealed" {
                continue;
@@ -392,11 +419,18 @@ fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, t
        }

        if ty.fields.is_empty().not() {
            if xml_ns_field.is_some() {
                g!("d.for_each_element_with_start(|d, x, start| match x {{");
            } else {
                g!("d.for_each_element(|d, x| match x {{");
            }
            for field in &ty.fields {
                if field.position == "sealed" {
                    continue;
                }
                if field.is_xml_attr {
                    continue;
                }

                let xml_name = field.xml_name.as_ref().unwrap_or(&field.camel_name);
                let field_name = field.name.as_str();
@@ -422,7 +456,66 @@ fn codegen_xml_serde_content_struct(_ops: &Operations, rust_types: &RustTypes, t
                    g!("{field_name} = Some(d.timestamp(TimestampFormat::{fmt})?);");
                } else {
                    g!("if {field_name}.is_some() {{ return Err(DeError::DuplicateField); }}");

                    if let Some(xml_ns_field) = xml_ns_field.filter(|x| x.name == field.name) {
                        let rust::Type::Struct(xml_ns_ty) = &rust_types[xml_ns_field.type_.as_str()] else { panic!() };
                        let mut xml_attr_field: Option<&StructField> = None;
                        for field in &xml_ns_ty.fields {
                            if field.is_xml_attr {
                                assert!(xml_attr_field.is_none());
                                xml_attr_field = Some(field);
                            }
                        }
                        let xml_attr_field = xml_attr_field.unwrap();

                        let xml_attr_name = &xml_attr_field.name;

                        g!("let mut {}: Option<{}> = None;", xml_attr_name, xml_attr_field.type_);
                        g!("for attr in start.attributes() {{");
                        g!("  let Ok(attr) = attr else {{ return Err(DeError::InvalidAttribute) }};");
                        g!("  if attr.key.as_ref() == b\"{}\" {{", xml_attr_field.xml_name.as_deref().unwrap());
                        g!(
                            "  {} = Some(attr.unescape_value().map_err(DeError::InvalidXml)?.into_owned().into());",
                            xml_attr_name
                        );
                        g!("  }}");
                        g!("}}");

                        for field in &xml_ns_ty.fields {
                            if field.is_xml_attr {
                                continue;
                            }
                            g!("let mut {}: Option<{}> = None;", field.name, field.type_);
                        }

                        g!("d.for_each_element(|d, x| match x {{");
                        for field in &xml_ns_ty.fields {
                            if field.is_xml_attr {
                                continue;
                            }
                            let xml_name = field.xml_name.as_ref().unwrap_or(&field.camel_name);
                            g!("b\"{xml_name}\" => {{");
                            let field_name = field.name.as_str();
                            g!("    if {field_name}.is_some() {{ return Err(DeError::DuplicateField); }}");
                            g!("        {field_name} = Some(d.content()?);");
                            g!("    Ok(())");
                            g!("}}");
                        }
                        g!("_ => Err(DeError::UnexpectedTagName),");
                        g!("}})?;");

                        g!("{field_name} = Some({} {{", field.type_);
                        for field in &xml_ns_ty.fields {
                            if field.is_xml_attr {
                                g!("{0}: {0}.ok_or(DeError::MissingField)?,", field.name);
                                continue;
                            }
                            g!("{},", field.name);
                        }
                        g!("}});");
                    } else {
                        g!("{field_name} = Some(d.content()?);");
                    }
                }

                g!("Ok(())");
+20 −0
Original line number Diff line number Diff line
@@ -229,11 +229,31 @@ impl Traits {
        self.get("smithy.api#xmlName")?.as_str()
    }

    /// Reports whether the `smithy.api#xmlAttribute` trait is present,
    /// i.e. whether `get` yields a value for that trait id.
    #[must_use]
    pub fn xml_attr(&self) -> bool {
        let attribute_trait = self.get("smithy.api#xmlAttribute");
        attribute_trait.is_some()
    }

    /// Reports whether the `smithy.api#xmlFlattened` trait is present,
    /// i.e. whether `get` yields a value for that trait id.
    #[must_use]
    pub fn xml_flattened(&self) -> bool {
        matches!(self.get("smithy.api#xmlFlattened"), Some(_))
    }

    /// Looks up the `smithy.api#xmlNamespace` trait and returns its value as an
    /// object map.
    ///
    /// Yields `None` when the trait is absent or when its value is not an object.
    #[must_use]
    fn xml_namespace(&self) -> Option<&Map<String, Value>> {
        let raw = self.get("smithy.api#xmlNamespace")?;
        raw.as_object()
    }

    /// Returns the `uri` entry of the `smithy.api#xmlNamespace` trait as a string.
    ///
    /// Yields `None` when the namespace trait is unavailable, when it has no
    /// `uri` key, or when that entry is not a string.
    #[must_use]
    pub fn xml_namespace_uri(&self) -> Option<&str> {
        let namespace = self.xml_namespace()?;
        let uri = namespace.get("uri")?;
        uri.as_str()
    }

    /// Returns the `prefix` entry of the `smithy.api#xmlNamespace` trait as a string.
    ///
    /// Yields `None` when the namespace trait is unavailable, when it has no
    /// `prefix` key, or when that entry is not a string.
    #[must_use]
    pub fn xml_namespace_prefix(&self) -> Option<&str> {
        let namespace = self.xml_namespace()?;
        let prefix = namespace.get("prefix")?;
        prefix.as_str()
    }

    #[must_use]
    pub fn s3_unwrapped_xml_output(&self) -> bool {
        self.get("aws.customizations#s3UnwrappedXmlOutput").is_some()
+33 −2
Original line number Diff line number Diff line
@@ -67,6 +67,12 @@ pub enum DeError {
    #[error("unexpected tag name")]
    UnexpectedTagName,

    #[error("invalid attribute")]
    InvalidAttribute,

    #[error("unexpected attribute name")]
    UnexpectedAttributeName,

    /// Invalid content
    #[error("invalid content")]
    InvalidContent,
@@ -240,8 +246,33 @@ impl<'xml> Deserializer<'xml> {
                    self.consume_peeked();

                    let name = start.name();
                    f(self, name.as_ref())?;
                    self.expect_end(name.as_ref())?;
                    let name = name.as_ref();
                    f(self, name)?;
                    self.expect_end(name)?;

                    continue;
                }
                DeEvent::Text(_) => {
                    self.consume_peeked();
                    continue;
                }
                DeEvent::End(_) | DeEvent::Eof => {
                    return Ok(());
                }
            }
        }
    }

    pub fn for_each_element_with_start(&mut self, mut f: impl FnMut(&mut Self, &[u8], &BytesStart<'_>) -> DeResult) -> DeResult {
        loop {
            match self.peek_event()? {
                DeEvent::Start(start) => {
                    self.consume_peeked();

                    let name = start.name();
                    let name = name.as_ref();
                    f(self, name, &start)?;
                    self.expect_end(name)?;

                    continue;
                }
Loading