Unverified Commit 7f40dfe2 authored by Nugine's avatar Nugine
Browse files

codegen: aws

parent 48be48c7
Loading
Loading
Loading
Loading

codegen/src/aws.rs

0 → 100644
+189 −0
Original line number Diff line number Diff line
use crate::dto::RustTypes;
use crate::f;
use crate::gen::Codegen;
use crate::ops::Operations;
use crate::rust;

pub fn codegen(ops: &Operations, rust_types: &RustTypes, g: &mut Codegen) {
    g.ln("use super::*;");
    g.lf();
    g.ln("use aws_sdk_s3::input::*;");
    g.ln("use aws_sdk_s3::output::*;");
    g.ln("use aws_sdk_s3::model::*;");
    g.ln("use aws_sdk_s3::error::*;");
    g.lf();

    for (name, rust_type) in rust_types {
        match rust_type {
            rust::Type::Alias(_) => continue,
            rust::Type::Provided(_) => continue,
            rust::Type::Timestamp(_) => continue,
            rust::Type::List(_) => continue,
            rust::Type::Map(_) => continue,
            rust::Type::UnitEnum(_) => {}
            rust::Type::Struct(_) => {}
            rust::Type::StructEnum(_) => {}
        }

        let s3s_path = f!("crate::dto::{name}");
        let aws_name = aws_ty_name(name);

        g.ln(f!("impl AwsConversion for {s3s_path} {{"));
        g.ln(f!("    type Target = {aws_name};"));
        g.lf();

        if contains_deprecated_field(name) {
            g.ln("#[allow(deprecated)]");
        }
        g.ln("fn try_from_aws(x: Self::Target) -> S3Result<Self> {");
        match rust_type {
            rust::Type::Struct(ty) => {
                if ty.fields.is_empty() {
                    g.ln("let _ = x;");
                }

                g.ln("Ok(Self {");
                for field in &ty.fields {
                    let s3s_field_name = field.name.as_str();
                    let aws_field_name = match s3s_field_name {
                        "checksum_crc32c" => "checksum_crc32_c",
                        "type_" => "r#type",
                        s => s,
                    };
                    let field_type = &rust_types[field.type_.as_str()];

                    'gen: {
                        if let rust::Type::Provided(ty) = field_type {
                            match ty.name.as_str() {
                                "StreamingBlob" => {
                                    g.ln(f!("{s3s_field_name}: stream_from_aws(x.{aws_field_name}),"));
                                    break 'gen;
                                }
                                "Body" => {}
                                "ContentType" | "CopySource" | "Range" => {
                                    // typed header value
                                }
                                _ => unimplemented!("{ty:#?}"),
                            }
                        }

                        if field.option_type || field.default_value.is_some() {
                            g.ln(f!("{s3s_field_name}: try_from_aws(x.{aws_field_name})?,"));
                        } else {
                            g.ln(f!(
                                "{s3s_field_name}: unwrap_from_aws(x.{aws_field_name}, \"{s3s_field_name}\")?,"
                            ));
                        }
                    }
                }
                g.ln("})");
            }
            rust::Type::UnitEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("{aws_name}::{0} => Self::{0},", variant.name));
                }
                g.ln("_ => unreachable!(),");
                g.ln("})");
            }
            rust::Type::StructEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("{aws_name}::{0}(v) => Self::{0}(try_from_aws(v)?),", variant.name));
                }
                g.ln("_ => unreachable!(),");
                g.ln("})");
            }
            _ => panic!(),
        }
        g.ln("}");
        g.lf();

        if contains_deprecated_field(name) {
            g.ln("#[allow(deprecated)]");
        }
        g.ln("fn try_into_aws(x: Self) -> S3Result<Self::Target> {");
        match rust_type {
            rust::Type::Struct(ty) => {
                if ty.fields.is_empty() {
                    g.ln("let _ = x;");
                    g.ln("let y = Self::Target::builder();");
                } else {
                    g.ln("let mut y = Self::Target::builder();");
                }

                for field in &ty.fields {
                    let s3s_field_name = field.name.as_str();
                    let aws_field_name = match s3s_field_name {
                        "checksum_crc32c" => "checksum_crc32_c",
                        "type_" => "type",
                        s => s,
                    };

                    if field.option_type {
                        g.ln(f!("y = y.set_{aws_field_name}(try_into_aws(x.{s3s_field_name})?);"));
                    } else {
                        g.ln(f!("y = y.set_{aws_field_name}(Some(try_into_aws(x.{s3s_field_name})?));"));
                    }
                }

                if is_op_input(&ty.name, ops) {
                    g.ln("y.build().map_err(S3Error::internal_error)");
                } else {
                    g.ln("Ok(y.build())");
                }
            }
            rust::Type::UnitEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("Self::{0} => {aws_name}::{0},", variant.name));
                }
                g.ln("})");
            }
            rust::Type::StructEnum(ty) => {
                g.ln("Ok(match x {");
                for variant in &ty.variants {
                    g.ln(f!("Self::{0}(v) => {aws_name}::{0}(try_into_aws(v)?),", variant.name));
                }
                g.ln("})");
            }
            _ => panic!(),
        }
        g.ln("}");

        g.ln("}");
        g.lf();
    }
}

fn aws_ty_name(name: &str) -> &str {
    match name {
        "BucketCannedACL" => "BucketCannedAcl",
        "CORSConfiguration" => "CorsConfiguration",
        "CORSRule" => "CorsRule",
        "CSVInput" => "CsvInput",
        "CSVOutput" => "CsvOutput",
        "JSONInput" => "JsonInput",
        "JSONOutput" => "JsonOutput",
        "JSONType" => "JsonType",
        "MFADelete" => "MfaDelete",
        "MFADeleteStatus" => "MfaDeleteStatus",
        "ObjectCannedACL" => "ObjectCannedAcl",
        "SSEKMS" => "Ssekms",
        "SSES3" => "Sses3",
        _ => name,
    }
}

fn is_op_input(name: &str, ops: &Operations) -> bool {
    if let Some(op) = name.strip_suffix("Input") {
        if ops.contains_key(op) {
            return true;
        }
    }
    false
}

fn contains_deprecated_field(name: &str) -> bool {
    matches!(name, "LifecycleRule" | "ReplicationRule")
}
+8 −0
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@ mod error;
mod headers;
mod ops;

mod aws;

use crate::gen::Codegen;

use std::format as f;
@@ -61,4 +63,10 @@ fn main() {
        let mut gen = Codegen::create_file(path).unwrap();
        ops::codegen(&ops, &rust_types, &mut gen);
    }

    {
        let path = "crates/s3s/src/aws/generated.rs";
        let mut gen = Codegen::create_file(path).unwrap();
        aws::codegen(&ops, &rust_types, &mut gen);
    }
}
+12 −1
Original line number Diff line number Diff line
@@ -9,11 +9,21 @@ repository = "https://github.com/Nugine/s3s"
keywords = ["s3"]
categories = ["web-programming", "web-programming::http-server"]

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[features]
aws = ["dep:aws-sdk-s3", "dep:aws-smithy-types-convert", "dep:aws-smithy-http"]

[dependencies]
arrayvec = "0.7.2"
ascii = { version = "1.1.0", default-features = false }
async-trait = "0.1.62"
atoi = { version = "2.0.0", default-features = false }
aws-sdk-s3 = { version = "0.23.0", optional = true }
aws-smithy-http = { version = "0.53.1", optional = true }
aws-smithy-types-convert = { version = "0.53.1", optional = true, features = ["convert-time"] }
base64-simd = "0.8.0"
bytes = "1.3.0"
chrono = { version = "0.4.23", default-features = false }
@@ -21,7 +31,8 @@ futures = { version = "0.3.25", default-features = false, features = ["std"] }
hex-simd = "0.8.0"
hmac = "0.12.1"
httparse = "1.8.0"
hyper = { version = "0.14.23", features = ["stream"] } # TODO(blocking): upgrade to hyper 1.0.0
# TODO(blocking): upgrade to hyper 1.0.0
hyper = { version = "0.14.23", features = ["stream"] }
memchr = "2.5.0"
mime = "0.3.16"
nom = "7.1.3"
+7701 −0

File added.

Preview size limit exceeded, changes collapsed.

+177 −0
Original line number Diff line number Diff line
mod generated;

use crate::error::{S3Error, S3Result};

use std::collections::HashMap;

pub trait AwsConversion: Sized {
    type Target;

    fn try_from_aws(x: Self::Target) -> S3Result<Self>;

    fn try_into_aws(x: Self) -> S3Result<Self::Target>;
}

pub fn try_from_aws<T: AwsConversion>(x: T::Target) -> S3Result<T> {
    T::try_from_aws(x)
}

pub fn try_into_aws<T: AwsConversion>(x: T) -> S3Result<T::Target> {
    T::try_into_aws(x)
}

fn unwrap_from_aws<T: AwsConversion>(opt: Option<T::Target>, field_name: &str) -> S3Result<T> {
    match opt {
        Some(x) => T::try_from_aws(x),
        None => Err(s3_error!(InternalError, "missing field: {}", field_name)),
    }
}

macro_rules! identity_impl {
    ($($ty:ty),+) => {
        $(
            impl AwsConversion for $ty {
                type Target = $ty;

                #[inline(always)]
                fn try_from_aws(x: Self::Target) -> S3Result<Self> {
                    Ok(x)
                }

                #[inline(always)]
                fn try_into_aws(x: Self) -> S3Result<Self::Target> {
                    Ok(x)
                }
            }
        )+
    };
}

identity_impl!(bool, i32, i64, String, HashMap<String, String>);

impl<T: AwsConversion> AwsConversion for Option<T> {
    type Target = Option<T::Target>;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        x.map(try_from_aws).transpose()
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        x.map(try_into_aws).transpose()
    }
}

impl<T: AwsConversion> AwsConversion for Vec<T> {
    type Target = Vec<T::Target>;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        x.into_iter().map(try_from_aws).collect()
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        x.into_iter().map(try_into_aws).collect()
    }
}

impl AwsConversion for crate::dto::Timestamp {
    type Target = aws_sdk_s3::types::DateTime;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        use aws_smithy_types_convert::date_time::DateTimeExt;
        Ok(Self::from(x.to_time().map_err(S3Error::internal_error)?))
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        use aws_smithy_types_convert::date_time::DateTimeExt;
        Ok(aws_sdk_s3::types::DateTime::from_time(x.into()))
    }
}

impl AwsConversion for crate::dto::ContentType {
    type Target = String;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        x.parse::<Self>().map_err(S3Error::internal_error)
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(x.to_string())
    }
}

impl AwsConversion for crate::dto::CopySource {
    type Target = String;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        Self::parse(x.as_str()).map_err(S3Error::internal_error)
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(x.format_to_string())
    }
}

impl AwsConversion for crate::dto::Range {
    type Target = String;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        Self::parse(x.as_str()).map_err(S3Error::internal_error)
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(x.format_to_string())
    }
}

impl AwsConversion for crate::dto::Event {
    type Target = aws_sdk_s3::model::Event;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        Ok(Self::from(x.as_str().to_owned()))
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(Self::Target::from(x))
    }
}

fn stream_from_aws(x: aws_sdk_s3::types::ByteStream) -> Option<crate::dto::StreamingBlob> {
    Some(crate::dto::StreamingBlob::wrap(x))
}

impl AwsConversion for crate::dto::StreamingBlob {
    type Target = aws_sdk_s3::types::ByteStream;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        Ok(Self::wrap(x))
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(hyper::Body::wrap_stream(x).into())
    }
}

impl AwsConversion for crate::dto::Body {
    type Target = aws_sdk_s3::types::Blob;

    fn try_from_aws(x: Self::Target) -> S3Result<Self> {
        Ok(x.into_inner().into())
    }

    fn try_into_aws(x: Self) -> S3Result<Self::Target> {
        Ok(Self::Target::new(x))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mime() {
        let x: crate::dto::ContentType = mime::TEXT_XML;
        let y = try_into_aws(x.clone()).unwrap();
        let z: crate::dto::ContentType = try_from_aws(y).unwrap();
        assert_eq!(x, z);
    }
}
Loading