Unverified Commit 0677bfc5 authored by Nugine's avatar Nugine
Browse files

codegen: ops: unify type name

parent ddc3eda7
Loading
Loading
Loading
Loading
+40 −15
Original line number Diff line number Diff line
@@ -17,8 +17,8 @@ pub fn to_type_name(shape_name: &str) -> &str {
pub type RustTypes = BTreeMap<String, rust::Type>;

pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes {
    let mut ans: BTreeMap<String, rust::Type> = default();
    let mut insert = |k: String, v: rust::Type| assert!(ans.insert(k, v).is_none());
    let mut space: BTreeMap<String, rust::Type> = default();
    let mut insert = |k: String, v: rust::Type| assert!(space.insert(k, v).is_none());

    for (shape_name, shape) in &model.shapes {
        let name = to_type_name(shape_name).to_owned();
@@ -123,17 +123,6 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
                insert(name, ty);
            }
            smithy::Shape::Structure(shape) => {
                let name = match name.strip_suffix("Request") {
                    Some(op_name) => {
                        if ops.contains_key(op_name) {
                            f!("{op_name}Input")
                        } else {
                            name
                        }
                    }
                    None => name,
                };

                let mut fields = Vec::new();
                for (field_name, field) in &shape.members {
                    let name = if field_name == "Type" {
@@ -230,9 +219,45 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
        }
    }

    ans.insert(o("Unit"), rust::Type::provided("Unit"));
    // unify operation input type
    for op in ops.values() {
        let input_ty = if op.smithy_input == "Unit" {
            rust::Struct {
                name: op.input.clone(),
                fields: default(),
                doc: None,
                xml_name: None,
            }
        } else {
            assert!(op.smithy_input.ends_with("Request"));
            let Some(rust::Type::Struct(mut ty)) = space.remove(&op.smithy_input) else { panic!() };
            ty.name = op.input.clone(); // rename type
            ty
        };
        assert!(space.insert(op.input.clone(), rust::Type::Struct(input_ty)).is_none());
    }

    // unify operation output type
    for op in ops.values() {
        let output_ty = if op.smithy_output == "Unit" {
            rust::Struct {
                name: op.output.clone(),
                fields: default(),
                doc: None,
                xml_name: None,
            }
        } else {
            if op.smithy_output == op.output {
                continue;
            }
            let rust::Type::Struct(mut ty) = space[&op.smithy_output].clone() else { panic!() };
            ty.name = op.output.clone(); // duplicate type
            ty
        };
        assert!(space.insert(op.output.clone(), rust::Type::Struct(output_ty)).is_none());
    }

    ans
    space
}

pub fn codegen(rust_types: &RustTypes, g: &mut Codegen) {
+265 −273
Original line number Diff line number Diff line
@@ -14,8 +14,13 @@ use serde_json::Value;
#[derive(Debug)]
pub struct Operation {
    pub name: String,

    pub input: String,
    pub output: String,

    pub smithy_input: String,
    pub smithy_output: String,

    pub doc: Option<String>,

    pub http_method: String,
@@ -32,42 +37,52 @@ pub fn collect_operations(model: &smithy::Model) -> Operations {
    for (shape_name, shape) in &model.shapes {
        let smithy::Shape::Operation(sh) = shape else { continue };

        let name = dto::to_type_name(shape_name).to_owned();
        if name == "SelectObjectContent" {
        let op_name = dto::to_type_name(shape_name).to_owned();
        if op_name == "SelectObjectContent" {
            continue; // TODO(further): impl SelectObjectContent
        }

        let input = {
            let sn = sh.input.target.as_str();
            if sn == "smithy.api#Unit" {
        let cvt = |n| {
            if n == "smithy.api#Unit" {
                o("Unit")
            } else {
                assert_eq!(dto::to_type_name(sn).strip_suffix("Request").unwrap(), name);
                f!("{name}Input")
                o(dto::to_type_name(n))
            }
        };

        let smithy_input = cvt(sh.input.target.as_str());
        let smithy_output = cvt(sh.output.target.as_str());

        let input = {
            if smithy_input != "Unit" {
                assert_eq!(smithy_input.strip_suffix("Request").unwrap(), op_name);
            }
            f!("{op_name}Input")
        };

        let output = {
            let sn = sh.output.target.as_str();
            if sn == "smithy.api#Unit" {
                o("Unit")
            } else {
                o(dto::to_type_name(sn))
            if smithy_output != "Unit" && smithy_output != "NotificationConfiguration" {
                assert_eq!(smithy_output.strip_suffix("Output").unwrap(), op_name);
            }
            f!("{op_name}Output")
        };

        let op = Operation {
            name: name.clone(),
            name: op_name.clone(),

            input,
            output,

            smithy_input,
            smithy_output,

            doc: sh.traits.doc().map(o),

            http_method: sh.traits.http_method().unwrap().to_owned(),
            http_uri: sh.traits.http_uri().unwrap().to_owned(),
            http_code: sh.traits.http_code().unwrap(),
        };
        insert(name, op);
        insert(op_name, op);
    }

    operations
@@ -116,25 +131,11 @@ fn codegen_async_trait(ops: &Operations, g: &mut Codegen) {

        let method_name = op.name.to_snake_case();

        let input_is_unit = op.input == "Unit";
        let output_is_unit = op.output == "Unit";

        match (input_is_unit, output_is_unit) {
            (false, false) => {
        g.ln(f!(
            "async fn {method_name}(&self, _input: {}) -> S3Result<{}> {{",
            op.input,
            op.output
        ));
            }
            (false, true) => {
                g.ln(f!("async fn {method_name}(&self, _input: {}) -> S3Result {{", op.input));
            }
            (true, false) => {
                g.ln(f!("async fn {method_name}(&self) -> S3Result<{}> {{", op.output));
            }
            (true, true) => panic!(),
        }

        g.ln(f!("Err(s3_error!(NotImplemented, \"{} is not implemented yet\"))", op.name));
        g.ln("}");
@@ -163,10 +164,6 @@ fn codegen_xml_ser(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
    for op in ops.values() {
        let ty_name = op.output.as_str();

        if ty_name == "Unit" {
            continue;
        }

        let rust_type = &rust_types[ty_name];
        let rust::Type::Struct(ty) = rust_type else { panic!() };

@@ -350,9 +347,6 @@ fn codegen_xml_de(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {

    for op in ops.values() {
        let ty_name = op.input.as_str();
        if ty_name == "Unit" {
            continue;
        }

        let rust_type = &rust_types[ty_name];
        let rust::Type::Struct(ty) = rust_type else { panic!() };
@@ -658,6 +652,18 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)
            g.ln("}");
        }
        rust::Type::Struct(ty) => {
            if ty.fields.is_empty() {
                g.ln(f!("pub fn serialize_http(_: {output}) -> S3Result<http::Response> {{"));
                if op.http_code == 200 {
                    g.ln("Ok(http::Response::default())");
                } else {
                    g.ln("let mut res = http::Response::default();");
                    let code_name = status_code_name(op.http_code);
                    g.ln(f!("*res.status_mut() = http::StatusCode::{code_name};"));
                    g.ln("Ok(res)");
                }
                g.ln("}");
            } else {
                g.ln(f!("pub fn serialize_http(x: {output}) -> S3Result<http::Response> {{"));

                assert!(ty.fields.is_empty().not());
@@ -734,6 +740,7 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes, g: &mut Codegen)

                g.ln("}");
            }
        }
        _ => unimplemented!(),
    }
    g.lf();
@@ -747,6 +754,11 @@ fn codegen_op_http_de(op: &Operation, rust_types: &RustTypes, g: &mut Codegen) {
            assert_eq!(ty.name, "Unit");
        }
        rust::Type::Struct(ty) => {
            if ty.fields.is_empty() {
                g.ln(f!("pub fn deserialize_http(_: &mut http::Request) -> S3Result<{input}> {{"));
                g.ln(f!("Ok({input} {{}})"));
                g.ln("}");
            } else {
                g.ln(f!("pub fn deserialize_http(req: &mut http::Request) -> S3Result<{input}> {{"));

                if op.name == "PutObject" {
@@ -930,6 +942,7 @@ fn codegen_op_http_de(op: &Operation, rust_types: &RustTypes, g: &mut Codegen) {
                    codegen_op_http_de_multipart(op, rust_types, g);
                }
            }
        }
        _ => unimplemented!(),
    }
    g.lf();
@@ -1029,26 +1042,16 @@ fn codegen_op_http_call(op: &Operation, g: &mut Codegen) {
    g.ln("}");
    g.lf();

    let arg = if op.input != "Unit" { "req" } else { "_" };
    g.ln(f!(
        "async fn call(&self, s3: &dyn S3, {arg}: &mut http::Request) -> S3Result<http::Response> {{"
    ));
    g.ln("async fn call(&self, s3: &dyn S3, req: &mut http::Request) -> S3Result<http::Response> {");

    let method = op.name.to_snake_case();

    if op.input != "Unit" {
    g.ln("let input = Self::deserialize_http(req)?;");
    g.ln(f!("let result = s3.{method}(input).await;"));
    } else {
        g.ln(f!("let result = s3.{method}().await;"));
    }

    g.ln("let res = match result {");
    if op.output != "Unit" {
    g.ln("Ok(output) => Self::serialize_http(output)?,");
    } else {
        g.ln("Ok(()) => Self::serialize_http(),");
    }

    g.ln("Err(err) => super::serialize_error(err)?,");
    g.ln("};");

@@ -1142,10 +1145,6 @@ fn collect_routes<'a>(ops: &'a Operations, rust_types: &'a RustTypes) -> HashMap
}

fn required_headers<'a>(op: &Operation, rust_types: &'a RustTypes) -> Vec<&'a str> {
    if op.input == "Unit" {
        return default();
    }

    let input_type = &rust_types[op.input.as_str()];
    let rust::Type::Struct(ty) = input_type else { panic!() };

@@ -1161,10 +1160,6 @@ fn required_headers<'a>(op: &Operation, rust_types: &'a RustTypes) -> Vec<&'a st
}

fn required_query_strings<'a>(op: &Operation, rust_types: &'a RustTypes) -> Vec<&'a str> {
    if op.input == "Unit" {
        return default();
    }

    let input_type = &rust_types[op.input.as_str()];
    let rust::Type::Struct(ty) = input_type else { panic!() };

@@ -1180,9 +1175,6 @@ fn required_query_strings<'a>(op: &Operation, rust_types: &'a RustTypes) -> Vec<
}

fn needs_full_body(op: &Operation, rust_types: &RustTypes) -> bool {
    if op.input == "Unit" {
        return false;
    }
    if op.http_method == "GET" {
        return false;
    }
+13 −13
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ use crate::gen::Codegen;

use serde_json::Value;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Type {
    Alias(Alias),
    Provided(Provided),
@@ -15,32 +15,32 @@ pub enum Type {
    Timestamp(Timestamp),
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Alias {
    pub name: String,
    pub type_: String,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Provided {
    pub name: String,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct List {
    pub name: String,
    pub member: ListMember,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ListMember {
    pub type_: String,
    pub xml_name: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Map {
    pub name: String,
    pub key_type: String,
@@ -48,21 +48,21 @@ pub struct Map {
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnitEnum {
    pub name: String,
    pub variants: Vec<UnitEnumVariant>,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnitEnumVariant {
    pub name: String,
    pub value: String,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Struct {
    pub name: String,
    pub fields: Vec<StructField>,
@@ -71,7 +71,7 @@ pub struct Struct {
    pub xml_name: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructField {
    pub name: String,
    pub type_: String,
@@ -91,21 +91,21 @@ pub struct StructField {
    pub xml_flattened: bool,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructEnum {
    pub name: String,
    pub variants: Vec<StructEnumVariant>,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructEnumVariant {
    pub name: String,
    pub type_: String,
    pub doc: Option<String>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Timestamp {
    pub name: String,
    pub format: Option<String>,
+562 −0

File changed.

Preview size limit exceeded, changes collapsed.

+5 −5
Original line number Diff line number Diff line
@@ -83,10 +83,10 @@ impl S3 for FileSystem {
    }

    #[tracing::instrument]
    async fn delete_bucket(&self, input: DeleteBucketInput) -> S3Result {
    async fn delete_bucket(&self, input: DeleteBucketInput) -> S3Result<DeleteBucketOutput> {
        let path = self.get_bucket_path(&input.bucket)?;
        try_!(fs::remove_dir_all(path).await);
        Ok(())
        Ok(DeleteBucketOutput {})
    }

    #[tracing::instrument]
@@ -208,14 +208,14 @@ impl S3 for FileSystem {
    }

    #[tracing::instrument]
    async fn head_bucket(&self, input: HeadBucketInput) -> S3Result {
    async fn head_bucket(&self, input: HeadBucketInput) -> S3Result<HeadBucketOutput> {
        let path = self.get_bucket_path(&input.bucket)?;

        if !path.exists() {
            return Err(s3_error!(NoSuchBucket));
        }

        Ok(())
        Ok(HeadBucketOutput {})
    }

    #[tracing::instrument]
@@ -246,7 +246,7 @@ impl S3 for FileSystem {
    }

    #[tracing::instrument]
    async fn list_buckets(&self) -> S3Result<ListBucketsOutput> {
    async fn list_buckets(&self, _: ListBucketsInput) -> S3Result<ListBucketsOutput> {
        let mut buckets: Vec<Bucket> = Vec::new();
        let mut iter = try_!(fs::read_dir(&self.root).await);
        while let Some(entry) = try_!(iter.next_entry().await) {
Loading