Unverified Commit a42f0a3f authored by Nugine's avatar Nugine
Browse files

codegen: dto: builders

parent 9895238d
Loading
Loading
Loading
Loading
+1 −9
Original line number Diff line number Diff line
use crate::dto::RustTypes;
use crate::ops::is_op_input;
use crate::ops::Operations;
use crate::rust;

@@ -220,15 +221,6 @@ fn aws_ty_path(name: &str, ops: &Operations, rust_types: &RustTypes) -> String {
    f!("aws_sdk_s3::types::{aws_name}")
}

fn is_op_input(name: &str, ops: &Operations) -> bool {
    if let Some(op) = name.strip_suffix("Input") {
        if ops.contains_key(op) {
            return true;
        }
    }
    false
}

fn contains_deprecated_field(name: &str) -> bool {
    matches!(name, "LifecycleRule" | "ReplicationRule")
}
+129 −6
Original line number Diff line number Diff line
use crate::o;
use crate::ops::Operations;
use crate::ops::{is_op_input, Operations};
use crate::rust::codegen_doc;
use crate::{rust, smithy};

use std::borrow::Cow;
use std::collections::BTreeMap;
use std::ops::Not;

@@ -341,6 +342,7 @@ pub fn codegen(rust_types: &RustTypes, ops: &Operations) {
        "use std::fmt;"
        "use std::str::FromStr;"
        ""
        "use rust_utils::default::default;"
    ];

    for rust_type in rust_types.values() {
@@ -362,7 +364,7 @@ pub fn codegen(rust_types: &RustTypes, ops: &Operations) {
                codegen_str_enum(ty, rust_types);
            }
            rust::Type::Struct(ty) => {
                codegen_struct(ty, rust_types);
                codegen_struct(ty, rust_types, ops);
            }
            rust::Type::StructEnum(ty) => {
                codegen_struct_enum(ty, rust_types);
@@ -376,9 +378,10 @@ pub fn codegen(rust_types: &RustTypes, ops: &Operations) {
    }

    codegen_tests(ops);
    codegen_builders(rust_types, ops);
}

fn codegen_struct(ty: &rust::Struct, rust_types: &RustTypes) {
fn codegen_struct(ty: &rust::Struct, rust_types: &RustTypes, ops: &Operations) {
    codegen_doc(ty.doc.as_deref());
    if can_derive_default(ty, rust_types) {
        g!("#[derive(Default)]");
@@ -412,6 +415,18 @@ fn codegen_struct(ty: &rust::Struct, rust_types: &RustTypes) {
    g!("d.finish_non_exhaustive()");
    g!("}}");
    g!("}}");
    g!();

    if is_op_input(&ty.name, ops) {
        g!("impl {} {{", ty.name);

        g!("#[must_use]");
        g!("pub fn builder() -> builders::{}Builder {{", ty.name);
        g!("default()");
        g!("}}");

        g!("}}");
    }
}

fn codegen_str_enum(ty: &rust::StrEnum, _rust_types: &RustTypes) {
@@ -509,14 +524,16 @@ fn can_derive_default(ty: &rust::Struct, rust_types: &RustTypes) -> bool {
    ty.fields.iter().all(|field| is_default_field(field, rust_types))
}

fn is_default_field(field: &rust::StructField, rust_types: &RustTypes) -> bool {
    let is_rust_default = |v: &Value| match v {
fn is_rust_default(v: &Value) -> bool {
    match v {
        Value::Bool(x) => !x,
        Value::Number(x) => x.as_i64() == Some(0),
        Value::String(x) => x.is_empty(),
        _ => unimplemented!("{v:#?}"),
    };
    }
}

fn is_default_field(field: &rust::StructField, rust_types: &RustTypes) -> bool {
    if field.option_type {
        return true;
    }
@@ -527,3 +544,109 @@ fn is_default_field(field: &rust::StructField, rust_types: &RustTypes) -> bool {

    field.default_value.as_ref().map_or(false, is_rust_default)
}

fn codegen_builders(rust_types: &RustTypes, ops: &Operations) {
    glines!(
        "pub mod builders {" //
        "#![allow(clippy::missing_errors_doc)]"
        ""
        "use super::*;"
        "pub use super::build_error::BuildError;"
        ""
    );

    for op in ops.values() {
        let rust::Type::Struct(ty) = &rust_types[&op.input] else {continue};
        codegen_struct_builder(ty, rust_types);
        g!();
    }

    g!("}}");
}

fn is_list_or_map(name: &str, rust_types: &RustTypes) -> bool {
    matches!(&rust_types[name], rust::Type::List(_) | rust::Type::Map(_))
}

fn codegen_struct_builder(ty: &rust::Struct, rust_types: &RustTypes) {
    g!("/// A builder for [`{}`]", ty.name);

    g!("#[derive(Default)]");
    g!("pub struct {}Builder {{", ty.name);

    for field in &ty.fields {
        if field.option_type {
            g!("{}: Option<{}>,", field.name, field.type_);
            g!();
            continue;
        }

        if is_list_or_map(&field.type_, rust_types) {
            g!("{}: {},", field.name, field.type_);
            g!();
            continue;
        }

        if let Some(ref v) = field.default_value {
            assert!(is_rust_default(v));
            g!("{}: {},", field.name, field.type_);
            g!();
            continue;
        }

        g!("{}: Option<{}>,", field.name, field.type_);
        g!();
    }

    g!("}}");
    g!();

    g!("impl {}Builder {{", ty.name);

    for field in &ty.fields {
        let field_name = field.name.as_str();

        let struct_field_type = if field.option_type {
            Cow::Owned(format!("Option<{}>", field.type_))
        } else {
            Cow::Borrowed(&field.type_)
        };

        g!("pub fn set_{field_name}(&mut self, field: {struct_field_type}) {{");

        if field.option_type || field.default_value.is_some() || is_list_or_map(&field.type_, rust_types) {
            g!("    self.{field_name} = field;");
        } else {
            g!("    self.{field_name} = Some(field);");
        }

        g!("}}");
        g!();
    }

    {
        g!("pub fn build(self) -> Result<{}, BuildError> {{", ty.name);

        for field in &ty.fields {
            let field_name = field.name.as_str();

            if field.option_type || field.default_value.is_some() || is_list_or_map(&field.type_, rust_types) {
                g!("let {field_name} = self.{field_name};");
            } else {
                g!("let {field_name} = self.{field_name}.ok_or_else(|| BuildError::missing_field({field_name:?}))?;");
            }
        }

        g!("Ok({} {{", ty.name);
        for field in &ty.fields {
            g!("{},", field.name);
        }
        g!("}})");

        g!("}}");
        g!();
    }

    g!("}}");
    g!();
}
+4 −0
Original line number Diff line number Diff line
@@ -98,6 +98,10 @@ pub fn collect_operations(model: &smithy::Model) -> Operations {
    operations
}

pub fn is_op_input(name: &str, ops: &Operations) -> bool {
    name.strip_suffix("Input").map_or(false, |x| ops.contains_key(x))
}

pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
    glines![
        "//! Auto generated by `codegen/src/ops.rs`"
+28 −0
Original line number Diff line number Diff line
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct BuildError {
    #[from]
    kind: BuildErrorKind,
}

#[derive(Debug, thiserror::Error)]
enum BuildErrorKind {
    #[error("Missing field: {field:?}")]
    MissingField { field: &'static str },
    // #[error("BuildError: {source}")]
    // Other { source: StdError },
}

impl BuildError {
    pub(crate) fn missing_field(field: &'static str) -> Self {
        Self {
            kind: BuildErrorKind::MissingField { field },
        }
    }

    // pub(crate) fn other(source: StdError) -> Self {
    //     Self {
    //         kind: BuildErrorKind::Other { source },
    //     }
    // }
}
+6432 −0

File changed.

Preview size limit exceeded, changes collapsed.

Loading