Unverified Commit c920f74d authored by Nugine's avatar Nugine
Browse files

codegen: dto: derive default for all output types

parent a6134063
Loading
Loading
Loading
Loading
+50 −19
Original line number Diff line number Diff line
@@ -327,7 +327,7 @@ fn unify_operation_types(ops: &Operations, space: &mut RustTypes) {
    }
}

pub fn codegen(rust_types: &RustTypes) {
pub fn codegen(rust_types: &RustTypes, ops: &Operations) {
    glines![
        "//! Auto generated by `codegen/src/dto.rs`"
        ""
@@ -374,27 +374,13 @@ pub fn codegen(rust_types: &RustTypes) {
        }
        g!();
    }
}

fn codegen_struct(ty: &rust::Struct, _rust_types: &RustTypes) {
    let is_rust_default = |v: &Value| match v {
        Value::Bool(x) => !x,
        Value::Number(x) => x.as_i64() == Some(0),
        Value::String(x) => x.is_empty(),
        _ => unimplemented!("{v:#?}"),
    };

    let can_derive_default = ty.fields.iter().all(|field| {
        let is_option = field.option_type;
        let has_default = match field.default_value {
            Some(ref v) => is_rust_default(v),
            None => false,
        };
        is_option || has_default
    });
    codegen_tests(ops);
}

fn codegen_struct(ty: &rust::Struct, rust_types: &RustTypes) {
    codegen_doc(ty.doc.as_deref());
    if can_derive_default {
    if can_derive_default(ty, rust_types) {
        g!("#[derive(Default)]");
    }
    // g!("#[non_exhaustive]"); // TODO: builder?
@@ -496,3 +482,48 @@ fn codegen_struct_enum(ty: &rust::StructEnum, _rust_types: &RustTypes) {

    g!("}}");
}

fn codegen_tests(ops: &Operations) {
    glines!(
        "#[cfg(test)]"
        "mod tests {"
        "use super::*;"
        ""
        "fn require_default<T: Default>() {}"
        ""
    );

    {
        g!("#[test]");
        g!("fn test_default() {{");
        for op in ops.values() {
            g!("require_default::<{}>();", op.output);
        }
        g!("}}");
    }

    g!("}}");
}

fn can_derive_default(ty: &rust::Struct, rust_types: &RustTypes) -> bool {
    ty.fields.iter().all(|field| is_default_field(field, rust_types))
}

fn is_default_field(field: &rust::StructField, rust_types: &RustTypes) -> bool {
    let is_rust_default = |v: &Value| match v {
        Value::Bool(x) => !x,
        Value::Number(x) => x.as_i64() == Some(0),
        Value::String(x) => x.is_empty(),
        _ => unimplemented!("{v:#?}"),
    };

    if field.option_type {
        return true;
    }

    if matches!(&rust_types[&field.type_], rust::Type::List(_)) {
        return true;
    }

    field.default_value.as_ref().map_or(false, is_rust_default)
}
+2 −2
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ fn main() {
        let json_file = std::fs::read(json_path).unwrap();
        serde_json::from_slice(&json_file).unwrap()
    };
    assert!(model.smithy == "2.0");
    assert_eq!(model.smithy, "2.0");

    let ops = ops::collect_operations(&model);
    let rust_types = dto::collect_rust_types(&model, &ops);
@@ -46,7 +46,7 @@ fn main() {
    {
        let path = "crates/s3s/src/dto/generated.rs";
        let gen = Codegen::create_file(path).unwrap();
        codegen_writer::scoped(gen, || dto::codegen(&rust_types));
        codegen_writer::scoped(gen, || dto::codegen(&rust_types, &ops));
    }

    {
+113 −0
Original line number Diff line number Diff line
@@ -478,6 +478,7 @@ pub type BucketKeyEnabled = bool;
/// <p>Specifies the lifecycle configuration for objects in an Amazon S3 bucket. For more
/// information, see <a href="https://docs.aws.amazon.com/AmazonS3/latest/dev/object-lifecycle-mgmt.html">Object Lifecycle Management</a>
/// in the <i>Amazon S3 User Guide</i>.</p>
#[derive(Default)]
pub struct BucketLifecycleConfiguration {
    /// <p>A lifecycle rule for individual objects in an Amazon S3 bucket.</p>
    pub rules: LifecycleRules,
@@ -686,6 +687,7 @@ pub type BytesScanned = i64;
/// <p>Describes the cross-origin access configuration for objects in an Amazon S3 bucket. For more
/// information, see <a href="https://docs.aws.amazon.com/AmazonS3/latest/dev/cors.html">Enabling
/// Cross-Origin Resource Sharing</a> in the <i>Amazon S3 User Guide</i>.</p>
#[derive(Default)]
pub struct CORSConfiguration {
    /// <p>A set of origins and methods (cross-origin access that you want to allow). You can add
    /// up to 100 rules to the configuration.</p>
@@ -701,6 +703,7 @@ impl fmt::Debug for CORSConfiguration {
}
/// <p>Specifies a cross-origin access rule for an Amazon S3 bucket.</p>
#[derive(Default)]
pub struct CORSRule {
    /// <p>Headers that are specified in the <code>Access-Control-Request-Headers</code> header.
    /// These headers are allowed in a preflight OPTIONS request. In response to any preflight
@@ -2306,6 +2309,7 @@ impl fmt::Debug for DefaultRetention {
}
/// <p>Container for the objects to delete.</p>
#[derive(Default)]
pub struct Delete {
    /// <p>The objects to delete.</p>
    pub objects: ObjectIdentifierList,
@@ -6073,6 +6077,7 @@ impl fmt::Debug for GetBucketTaggingInput {
    }
}
#[derive(Default)]
pub struct GetBucketTaggingOutput {
    /// <p>Contains the tag set.</p>
    pub tag_set: TagSet,
@@ -6908,6 +6913,7 @@ impl fmt::Debug for GetObjectTaggingInput {
    }
}
#[derive(Default)]
pub struct GetObjectTaggingOutput {
    /// <p>Contains the tag set.</p>
    pub tag_set: TagSet,
@@ -10808,6 +10814,7 @@ impl FromStr for OwnerOverride {
}
/// <p>The container element for a bucket's ownership controls.</p>
#[derive(Default)]
pub struct OwnershipControls {
    /// <p>The container element for an ownership control rule.</p>
    pub rules: OwnershipControlsRules,
@@ -14058,6 +14065,7 @@ impl fmt::Debug for ServerSideEncryptionByDefault {
}
/// <p>Specifies the default server-side-encryption configuration.</p>
#[derive(Default)]
pub struct ServerSideEncryptionConfiguration {
    /// <p>Container for information about a particular server-side encryption configuration
    /// rule.</p>
@@ -14387,6 +14395,7 @@ pub type TagCount = i32;
pub type TagSet = List<Tag>;
/// <p>Container for <code>TagSet</code> elements.</p>
#[derive(Default)]
pub struct Tagging {
    /// <p>A collection for a set of tags</p>
    pub tag_set: TagSet,
@@ -15505,3 +15514,107 @@ impl fmt::Debug for WriteGetObjectResponseOutput {
}
pub type Years = i32;
#[cfg(test)]
mod tests {
    use super::*;
    fn require_default<T: Default>() {}
    #[test]
    fn test_default() {
        require_default::<AbortMultipartUploadOutput>();
        require_default::<CompleteMultipartUploadOutput>();
        require_default::<CopyObjectOutput>();
        require_default::<CreateBucketOutput>();
        require_default::<CreateMultipartUploadOutput>();
        require_default::<DeleteBucketOutput>();
        require_default::<DeleteBucketAnalyticsConfigurationOutput>();
        require_default::<DeleteBucketCorsOutput>();
        require_default::<DeleteBucketEncryptionOutput>();
        require_default::<DeleteBucketIntelligentTieringConfigurationOutput>();
        require_default::<DeleteBucketInventoryConfigurationOutput>();
        require_default::<DeleteBucketLifecycleOutput>();
        require_default::<DeleteBucketMetricsConfigurationOutput>();
        require_default::<DeleteBucketOwnershipControlsOutput>();
        require_default::<DeleteBucketPolicyOutput>();
        require_default::<DeleteBucketReplicationOutput>();
        require_default::<DeleteBucketTaggingOutput>();
        require_default::<DeleteBucketWebsiteOutput>();
        require_default::<DeleteObjectOutput>();
        require_default::<DeleteObjectTaggingOutput>();
        require_default::<DeleteObjectsOutput>();
        require_default::<DeletePublicAccessBlockOutput>();
        require_default::<GetBucketAccelerateConfigurationOutput>();
        require_default::<GetBucketAclOutput>();
        require_default::<GetBucketAnalyticsConfigurationOutput>();
        require_default::<GetBucketCorsOutput>();
        require_default::<GetBucketEncryptionOutput>();
        require_default::<GetBucketIntelligentTieringConfigurationOutput>();
        require_default::<GetBucketInventoryConfigurationOutput>();
        require_default::<GetBucketLifecycleConfigurationOutput>();
        require_default::<GetBucketLocationOutput>();
        require_default::<GetBucketLoggingOutput>();
        require_default::<GetBucketMetricsConfigurationOutput>();
        require_default::<GetBucketNotificationConfigurationOutput>();
        require_default::<GetBucketOwnershipControlsOutput>();
        require_default::<GetBucketPolicyOutput>();
        require_default::<GetBucketPolicyStatusOutput>();
        require_default::<GetBucketReplicationOutput>();
        require_default::<GetBucketRequestPaymentOutput>();
        require_default::<GetBucketTaggingOutput>();
        require_default::<GetBucketVersioningOutput>();
        require_default::<GetBucketWebsiteOutput>();
        require_default::<GetObjectOutput>();
        require_default::<GetObjectAclOutput>();
        require_default::<GetObjectAttributesOutput>();
        require_default::<GetObjectLegalHoldOutput>();
        require_default::<GetObjectLockConfigurationOutput>();
        require_default::<GetObjectRetentionOutput>();
        require_default::<GetObjectTaggingOutput>();
        require_default::<GetObjectTorrentOutput>();
        require_default::<GetPublicAccessBlockOutput>();
        require_default::<HeadBucketOutput>();
        require_default::<HeadObjectOutput>();
        require_default::<ListBucketAnalyticsConfigurationsOutput>();
        require_default::<ListBucketIntelligentTieringConfigurationsOutput>();
        require_default::<ListBucketInventoryConfigurationsOutput>();
        require_default::<ListBucketMetricsConfigurationsOutput>();
        require_default::<ListBucketsOutput>();
        require_default::<ListMultipartUploadsOutput>();
        require_default::<ListObjectVersionsOutput>();
        require_default::<ListObjectsOutput>();
        require_default::<ListObjectsV2Output>();
        require_default::<ListPartsOutput>();
        require_default::<PutBucketAccelerateConfigurationOutput>();
        require_default::<PutBucketAclOutput>();
        require_default::<PutBucketAnalyticsConfigurationOutput>();
        require_default::<PutBucketCorsOutput>();
        require_default::<PutBucketEncryptionOutput>();
        require_default::<PutBucketIntelligentTieringConfigurationOutput>();
        require_default::<PutBucketInventoryConfigurationOutput>();
        require_default::<PutBucketLifecycleConfigurationOutput>();
        require_default::<PutBucketLoggingOutput>();
        require_default::<PutBucketMetricsConfigurationOutput>();
        require_default::<PutBucketNotificationConfigurationOutput>();
        require_default::<PutBucketOwnershipControlsOutput>();
        require_default::<PutBucketPolicyOutput>();
        require_default::<PutBucketReplicationOutput>();
        require_default::<PutBucketRequestPaymentOutput>();
        require_default::<PutBucketTaggingOutput>();
        require_default::<PutBucketVersioningOutput>();
        require_default::<PutBucketWebsiteOutput>();
        require_default::<PutObjectOutput>();
        require_default::<PutObjectAclOutput>();
        require_default::<PutObjectLegalHoldOutput>();
        require_default::<PutObjectLockConfigurationOutput>();
        require_default::<PutObjectRetentionOutput>();
        require_default::<PutObjectTaggingOutput>();
        require_default::<PutPublicAccessBlockOutput>();
        require_default::<RestoreObjectOutput>();
        require_default::<SelectObjectContentOutput>();
        require_default::<UploadPartOutput>();
        require_default::<UploadPartCopyOutput>();
        require_default::<WriteGetObjectResponseOutput>();
    }
}