Unverified Commit e2f003cf authored by Nugine's avatar Nugine
Browse files

codegen: strenum

parent eac3a2be
Loading
Loading
Loading
Loading
+16 −13
Original line number Diff line number Diff line
use heck::ToUpperCamelCase;

use crate::dto::RustTypes;
use crate::f;
use crate::gen::Codegen;
@@ -20,7 +22,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
            rust::Type::Timestamp(_) => continue,
            rust::Type::List(_) => continue,
            rust::Type::Map(_) => continue,
            rust::Type::UnitEnum(_) => {}
            rust::Type::StrEnum(_) => {}
            rust::Type::Struct(_) => {}
            rust::Type::StructEnum(_) => {}
        }
@@ -76,12 +78,18 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                }
                g.ln("})");
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("{aws_name}::{0} => Self::{0},", variant.name));
                    let s3s_variant_name = variant.name.as_str();
                    let aws_variant_name = match s3s_variant_name {
                        "CRC32C" => "Crc32C".to_owned(),
                        _ => s3s_variant_name.to_upper_camel_case(),
                    };
                    g.ln(f!("{aws_name}::{aws_variant_name} => Self::from_static(Self::{s3s_variant_name}),"));
                }
                g.ln("_ => unreachable!(),");
                g.ln(f!("{aws_name}::Unknown(_) => Self::from(x.as_str().to_owned()),"));
                g.ln("_ => Self::from(x.as_str().to_owned()),");
                g.ln("})");
            }
            rust::Type::StructEnum(ty) => {
@@ -89,7 +97,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                for variant in &ty.variants {
                    g.ln(f!("{aws_name}::{0}(v) => Self::{0}(try_from_aws(v)?),", variant.name));
                }
                g.ln("_ => unreachable!(),");
                g.ln(f!("_ => unimplemented!(\"unknown variant of {aws_name}: {{x:?}}\"),"));
                g.ln("})");
            }
            _ => panic!(),
@@ -131,20 +139,15 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                    g.ln("Ok(y.build())");
                }
            }
            rust::Type::UnitEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("Self::{0} => {aws_name}::{0},", variant.name));
                }
                g.ln("_ => unreachable!(),");
                g.ln("})");
            rust::Type::StrEnum(_) => {
                g.ln(f!("Ok({aws_name}::from(x.as_str()))"));
            }
            rust::Type::StructEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("Self::{0}(v) => {aws_name}::{0}(try_into_aws(v)?),", variant.name));
                }
                g.ln("_ => unreachable!(),");
                g.ln(f!("_ => unimplemented!(\"unknown variant of {}: {{x:?}}\"),", ty.name));
                g.ln("})");
            }
            _ => panic!(),
+33 −37
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ use crate::{rust, smithy};
use std::collections::BTreeMap;
use std::ops::Not;

use heck::{ToSnakeCase, ToUpperCamelCase};
use heck::{ToShoutySnakeCase, ToSnakeCase};
use serde_json::Value;

pub fn to_type_name(shape_name: &str) -> &str {
@@ -101,21 +101,21 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
                let mut variants = Vec::new();
                for (variant_name, variant) in &shape.members {
                    let name = match variant_name.as_str() {
                        "CRC32C" => o("Crc32C"),
                        _ => variant_name.to_upper_camel_case(),
                        "CRC32C" => o("CRC32C"),
                        _ => variant_name.to_shouty_snake_case(),
                    };

                    let value = variant.traits.enum_value().unwrap().to_owned();
                    assert!(value.is_ascii());

                    let variant = rust::UnitEnumVariant {
                    let variant = rust::StrEnumVariant {
                        name,
                        value,
                        doc: variant.traits.doc().map(o),
                    };
                    variants.push(variant);
                }
                let ty = rust::Type::UnitEnum(rust::UnitEnum {
                let ty = rust::Type::StrEnum(rust::StrEnum {
                    name: name.clone(),
                    variants,
                    doc: shape.traits.doc().map(o),
@@ -267,7 +267,9 @@ pub fn codegen(rust_types: &RustTypes, g: &mut Codegen) {
        "",
        "use super::*;",
        "",
        "use std::borrow::Cow;",
        "use std::str::FromStr;",
        "use std::convert::Infallible;",
        "",
    ];

@@ -290,59 +292,53 @@ pub fn codegen(rust_types: &RustTypes, g: &mut Codegen) {
                codegen_doc(ty.doc.as_deref(), g);
                g.ln(f!("pub type {} = Map<{}, {}>;", ty.name, ty.key_type, ty.value_type));
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                codegen_doc(ty.doc.as_deref(), g);
                g.ln("#[derive(Debug, Clone, Copy, PartialEq, Eq)]");
                g.ln("#[non_exhaustive]");
                g.ln(f!("pub enum {} {{", ty.name));
                g.ln("#[derive(Debug, Clone, PartialEq, Eq)]");
                g.ln(f!("pub struct {}(Cow<'static, str>);", ty.name));
                g.lf();

                g.ln(f!("impl {} {{", ty.name));
                {
                    for variant in &ty.variants {
                        codegen_doc(variant.doc.as_deref(), g);
                    g.ln(f!("    {},", variant.name));
                        g.ln(f!("pub const {}: &str = \"{}\";", variant.name, variant.value));
                        g.lf();
                    }

                    g.ln("#[must_use]");
                    g.ln("pub fn as_str(&self) -> &str {");
                    g.ln("&self.0");
                    g.ln("}");

                    g.lf();
                g.ln(f!("impl {} {{", ty.name));

                {
                    g.ln("#[must_use]");
                    g.ln("pub const fn as_str(&self) -> &'static str {");

                    g.ln("match self {");
                    for variant in &ty.variants {
                        g.ln(f!("    Self::{} => \"{}\",", variant.name, variant.value));
                    }
                    g.ln("}");

                    g.ln("pub fn from_static(s: &'static str) -> Self {");
                    g.ln("Self(Cow::from(s))");
                    g.ln("}");
                    g.lf();
                }

                {
                    g.ln("#[must_use]");
                    g.ln("pub const fn from_bytes(s: &[u8]) -> Option<Self> {");

                    g.ln("match s {");
                    for variant in &ty.variants {
                        g.ln(f!("b\"{}\" => Some(Self::{}),", variant.value, variant.name));
                    }
                    g.ln("_ => None,");
                g.ln("}");
                g.lf();

                g.ln(f!("impl From<String> for {} {{", ty.name));
                g.ln("fn from(s: String) -> Self {");
                g.ln("Self(Cow::from(s))");
                g.ln("}");
                g.ln("}");
                g.lf();
                }

                g.ln(f!("impl From<{}> for Cow<'static, str> {{", ty.name));
                g.ln(f!("fn from(s: {}) -> Self {{", ty.name));
                g.ln("s.0");
                g.ln("}");
                g.ln("}");

                g.lf();

                g.ln(f!("impl FromStr for {} {{", ty.name));
                g.ln("type Err = ParseEnumError;");
                g.ln("type Err = Infallible;");
                g.ln("fn from_str(s: &str) -> Result<Self, Self::Err> {");
                g.ln("Self::from_bytes(s.as_bytes()).ok_or(ParseEnumError(()))");
                g.ln("Ok(Self::from(s.to_owned()))");
                g.ln("}");
                g.ln("}");
            }
+22 −29
Original line number Diff line number Diff line
@@ -101,6 +101,7 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
        "use crate::path::S3Path;",
        "",
        "use std::io::Write;",
        "use std::borrow::Cow;",
        "",
    ];

@@ -206,7 +207,7 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                field_type_names.insert(ty.member.type_.as_str());
                q.push_back(ty.member.type_.as_str());
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                field_type_names.insert(ty.name.as_str());
            }
            rust::Type::StructEnum(ty) => {
@@ -273,7 +274,7 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                g.ln("}");
                g.ln("}");
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                g.ln(f!("impl xml::SerializeContent for {} {{", ty.name));
                g.ln("fn serialize_content<W: Write>(&self, s: &mut xml::Serializer<W>) -> xml::SerResult {");

@@ -381,7 +382,7 @@ fn codegen_xml_de(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                field_type_names.insert(ty.member.type_.as_str());
                q.push_back(ty.member.type_.as_str());
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                field_type_names.insert(ty.name.as_str());
            }
            rust::Type::StructEnum(ty) => {
@@ -469,12 +470,10 @@ fn codegen_xml_de(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
                g.ln("}");
                g.ln("}");
            }
            rust::Type::UnitEnum(ty) => {
            rust::Type::StrEnum(ty) => {
                g.ln(f!("impl<'xml> xml::DeserializeContent<'xml> for {} {{", ty.name));
                g.ln("fn deserialize_content(d: &mut xml::Deserializer<'xml>) -> xml::DeResult<Self> {");

                g.ln("d.text(|t|Self::from_bytes(t.as_ref()).ok_or(xml::DeError::InvalidContent))");

                g.ln("String::deserialize_content(d).map(Self::from)");
                g.ln("}");
                g.ln("}");
            }
@@ -548,7 +547,7 @@ fn codegen_http(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
}

fn codegen_header_value(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
    let mut unit_enum_names: BTreeSet<&str> = default();
    let mut str_enum_names: BTreeSet<&str> = default();

    for op in ops.values() {
        for ty_name in [op.input.as_str(), op.output.as_str()] {
@@ -561,12 +560,12 @@ fn codegen_header_value(ops: &Operations, rust_types: &RustTypes, g: &mut Codege
                        match field_type {
                            rust::Type::List(list_ty) => {
                                let member_type = &rust_types[list_ty.member.type_.as_str()];
                                if let rust::Type::UnitEnum(ty) = member_type {
                                    unit_enum_names.insert(ty.name.as_str());
                                if let rust::Type::StrEnum(ty) = member_type {
                                    str_enum_names.insert(ty.name.as_str());
                                }
                            }
                            rust::Type::UnitEnum(ty) => {
                                unit_enum_names.insert(ty.name.as_str());
                            rust::Type::StrEnum(ty) => {
                                str_enum_names.insert(ty.name.as_str());
                            }
                            rust::Type::Alias(_) => {}
                            rust::Type::Provided(_) => {}
@@ -580,35 +579,29 @@ fn codegen_header_value(ops: &Operations, rust_types: &RustTypes, g: &mut Codege
        }
    }

    for rust_type in unit_enum_names.iter().map(|&x| &rust_types[x]) {
        let rust::Type::UnitEnum(ty) = rust_type else { panic!() };
    for rust_type in str_enum_names.iter().map(|&x| &rust_types[x]) {
        let rust::Type::StrEnum(ty) = rust_type else { panic!() };

        g.ln(f!("impl http::TryIntoHeaderValue for {} {{", ty.name));
        g.ln("type Error = std::convert::Infallible;");
        g.ln("type Error = http::InvalidHeaderValue;");
        g.ln("fn try_into_header_value(self) -> Result<http::HeaderValue, Self::Error> {");
        g.ln("match self {");
        for variant in &ty.variants {
            g.ln(f!("Self::{0} => {{", variant.name));

            // TODO(blocking): inline_const https://github.com/rust-lang/rust/pull/104087
            // g.ln("Ok(const { http::HeaderValue::from_static(\"{}\") })");

            g.ln("Ok(http::HeaderValue::from_static(\"{}\"))");
            g.ln("}");
        }
        g.ln("    match Cow::from(self) {");
        g.ln("        Cow::Borrowed(s) => http::HeaderValue::try_from(s),");
        g.ln("        Cow::Owned(s) => http::HeaderValue::try_from(s),");
        g.ln("    }");
        g.ln("}");
        g.ln("}");
        g.lf();
    }

    for rust_type in unit_enum_names.iter().map(|&x| &rust_types[x]) {
        let rust::Type::UnitEnum(ty) = rust_type else { panic!() };
    for rust_type in str_enum_names.iter().map(|&x| &rust_types[x]) {
        let rust::Type::StrEnum(ty) = rust_type else { panic!() };

        g.ln(f!("impl http::TryFromHeaderValue for {} {{", ty.name));
        g.ln("type Error = http::ParseHeaderError;");
        g.ln("fn try_from_header_value(val: &http::HeaderValue) -> Result<Self, Self::Error> {");
        g.ln("    Self::from_bytes(val.as_bytes()).ok_or(http::ParseHeaderError::Enum)");
        g.ln("    let val = val.to_str().map_err(|_|http::ParseHeaderError::Enum)?;");
        g.ln("    Ok(Self::from(val.to_owned()))");
        g.ln("}");
        g.ln("}");
        g.lf();
+4 −4
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ pub enum Type {
    Provided(Provided),
    List(List),
    Map(Map),
    UnitEnum(UnitEnum),
    StrEnum(StrEnum),
    Struct(Struct),
    StructEnum(StructEnum),
    Timestamp(Timestamp),
@@ -49,14 +49,14 @@ pub struct Map {
}

#[derive(Debug, Clone)]
pub struct UnitEnum {
pub struct StrEnum {
    pub name: String,
    pub variants: Vec<UnitEnumVariant>,
    pub variants: Vec<StrEnumVariant>,
    pub doc: Option<String>,
}

#[derive(Debug, Clone)]
pub struct UnitEnumVariant {
pub struct StrEnumVariant {
    pub name: String,
    pub value: String,
    pub doc: Option<String>,
+363 −600

File changed.

Preview size limit exceeded, changes collapsed.

Loading