Unverified Commit d42ef52f authored by Nugine's avatar Nugine Committed by GitHub
Browse files

fix(s3s): complete_multipart_upload keep_alive (#348)

parent 12806eaa
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -100,6 +100,11 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
                        continue;
                    }

                    if field.position == "s3s" {
                        g!("{s3s_field_name}: None,");
                        continue;
                    }

                    let field_ty = &rust_types[field.type_.as_str()];

                    let needs_unwrap = 'unwrap: {
@@ -194,6 +199,10 @@ pub fn codegen(ops: &Operations, rust_types: &RustTypes) {
                        continue;
                    }

                    if field.position == "s3s" {
                        continue;
                    }

                    if field.option_type {
                        g!("y = y.set_{aws_field_name}(try_into_aws(x.{s3s_field_name})?);");
                    } else {
+31 −1
Original line number Diff line number Diff line
@@ -212,6 +212,8 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
                        xml_namespace_prefix: field.traits.xml_namespace_prefix().map(o),

                        is_custom_extension: field.traits.minio(),

                        custom_in_derive_debug: None,
                    };
                    fields.push(field);
                }
@@ -255,6 +257,20 @@ pub fn collect_rust_types(model: &smithy::Model, ops: &Operations) -> RustTypes
}

fn patch_types(space: &mut RustTypes) {
    // patch CompleteMultipartUploadOutput
    {
        let Some(rust::Type::Struct(ty)) = space.get_mut("CompleteMultipartUploadOutput") else { panic!() };
        ty.fields.push(rust::StructField {
            name: o("future"),
            type_: o("BoxFuture<'static, S3Result<CompleteMultipartUploadOutput>>"),
            option_type: true,
            position: o("s3s"),
            doc: Some(o("A future that resolves to the upload output or an error. This field is used to implement AWS-like keep-alive behavior.")),
            custom_in_derive_debug: Some(o("&\"<BoxFuture<'static, S3Result<CompleteMultipartUploadOutput>>>\"")),    
            ..rust::StructField::default()
        });
    }

    // patch PartNumberMarker
    // FIXME: https://github.com/awslabs/aws-sdk-rust/issues/1318
    {
@@ -324,6 +340,7 @@ fn patch_types(space: &mut RustTypes) {
            xml_namespace_uri: None,
            xml_namespace_prefix: None,
            is_custom_extension: false,
            custom_in_derive_debug: None,
        });
        ty.name = o("SelectObjectContentInput");

@@ -397,12 +414,14 @@ pub fn codegen(rust_types: &RustTypes, ops: &Operations, patch: Option<Patch>) {
        "#![allow(clippy::too_many_lines)]",
        "",
        "use super::*;",
        "use crate::error::S3Result;",
        "",
        "use std::borrow::Cow;",
        "use std::convert::Infallible;",
        "use std::fmt;",
        "use std::str::FromStr;",
        "",
        "use futures::future::BoxFuture;",
        "use stdx::default::default;",
        "use serde::{Serialize, Deserialize};",
        "",
@@ -478,7 +497,12 @@ fn codegen_struct(ty: &rust::Struct, rust_types: &RustTypes, ops: &Operations) {
    g!("fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {{");
    g!("let mut d = f.debug_struct(\"{}\");", ty.name);
    for field in &ty.fields {
        if field.option_type {
        if let Some(name) = &field.custom_in_derive_debug {
            assert!(field.option_type);
            g!("if self.{}.is_some() {{", field.name);
            g!("d.field(\"{}\", {});", field.name, name);
            g!("}}");
        } else if field.option_type {
            g!("if let Some(ref val) = self.{} {{", field.name);
            g!("d.field(\"{}\", val);", field.name);
            g!("}}");
@@ -649,6 +673,9 @@ fn can_derive_clone(ty: &rust::Struct, _rust_types: &RustTypes) -> bool {
        if field.position == "sealed" {
            return false;
        }
        if field.position == "s3s" {
            return false;
        }
        if matches!(field.type_.as_str(), "StreamingBlob" | "SelectObjectContentEventStream") {
            return false;
        }
@@ -661,6 +688,9 @@ fn can_derive_partial_eq(ty: &rust::Struct, _rust_types: &RustTypes) -> bool {
        if field.position == "sealed" {
            return false;
        }
        if field.position == "s3s" {
            return false;
        }
        if matches!(field.type_.as_str(), "StreamingBlob" | "SelectObjectContentEventStream") {
            return false;
        }
+22 −38
Original line number Diff line number Diff line
@@ -259,6 +259,7 @@ fn codegen_op_http_ser_unit(op: &Operation) {
    g!("}}");
}

#[allow(clippy::too_many_lines)]
fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {
    let output = op.output.as_str();
    let rust_type = &rust_types[output];
@@ -268,6 +269,10 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {
            codegen_op_http_ser_unit(op);
        }
        rust::Type::Struct(ty) => {
            if op.name == "CompleteMultipartUpload" {
                return; // custom implementation
            }

            if ty.fields.is_empty() {
                g!("pub fn serialize_http(_: {output}) -> S3Result<http::Response> {{");
                {
@@ -280,7 +285,7 @@ fn codegen_op_http_ser(op: &Operation, rust_types: &RustTypes) {

                assert!(ty.fields.is_empty().not());
                for field in &ty.fields {
                    assert!(["header", "metadata", "xml", "payload"].contains(&field.position.as_str()),);
                    assert!(["header", "metadata", "xml", "payload", "s3s"].contains(&field.position.as_str()),);
                }

                if op.name == "GetObject" {
@@ -690,27 +695,6 @@ fn codegen_op_http_call(op: &Operation) {
        g!("let overridden_headers = super::get_object::extract_overridden_response_headers(&s3_req)?;");
    }

    if op.name == "CompleteMultipartUpload" {
        g!("let s3 = s3.clone();");
        g!("let fut = async move {{");
        g!("let result = s3.{method}(s3_req).await;");
        g!("match result {{");
        g(["Ok(s3_resp) => {
                let mut resp = Self::serialize_http(s3_resp.output)?;
                resp.headers.extend(s3_resp.headers);
                Ok(resp)
            }"]);
        g!("Err(err) => super::serialize_error(err, true).map_err(Into::into),");
        g!("}}");
        g!("}};");
        g!("let mut resp = http::Response::with_status(http::StatusCode::OK);");
        g!(
            "http::set_keep_alive_xml_body(&mut resp, sync_wrapper::SyncFuture::new(fut), std::time::Duration::from_millis(100))?;"
        );
        g!(
            "http::add_opt_header(&mut resp, \"trailer\", Some([X_AMZ_SERVER_SIDE_ENCRYPTION_BUCKET_KEY_ENABLED.as_str(), X_AMZ_EXPIRATION.as_str(), X_AMZ_REQUEST_CHARGED.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID.as_str(), X_AMZ_SERVER_SIDE_ENCRYPTION.as_str(), X_AMZ_VERSION_ID.as_str()].join(\",\")))?;"
        );
    } else {
    g!("let result = s3.{method}(s3_req).await;");

    g([
@@ -730,7 +714,7 @@ fn codegen_op_http_call(op: &Operation) {
    }

    g!("resp.extensions.extend(s3_resp.extensions);");
    }

    g!("Ok(resp)");

    g!("}}");
+3 −1
Original line number Diff line number Diff line
@@ -78,7 +78,7 @@ pub struct Struct {
}

#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub struct StructField {
    pub name: String,
    pub type_: String,
@@ -103,6 +103,8 @@ pub struct StructField {
    pub xml_namespace_prefix: Option<String>,

    pub is_custom_extension: bool,

    pub custom_in_derive_debug: Option<String>,
}

#[derive(Debug, Clone)]
+1 −0
Original line number Diff line number Diff line
@@ -761,6 +761,7 @@ impl AwsConversion for s3s::dto::CompleteMultipartUploadOutput {
            ssekms_key_id: try_from_aws(x.ssekms_key_id)?,
            server_side_encryption: try_from_aws(x.server_side_encryption)?,
            version_id: try_from_aws(x.version_id)?,
            future: None,
        })
    }

Loading