Unverified Commit 2db1c34d authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Allow <Error> to trigger error handling for S3 (#2958)

## Motivation and Context
- https://github.com/awslabs/aws-sdk-rust/issues/873

## Description
Add a customization for S3 so that if `<Error>` is the root element, we
trigger the error parsing flow

## Testing
- Added an integration test

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 9f396079
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -55,6 +55,12 @@ references = ["smithy-rs#2948"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "Velfi"

[[aws-sdk-rust]]
message = "Correctly identify HTTP 200 responses from S3 with `<Error>` as the root Element as errors. **Note**: This a behavior change and will change the error type returned by the SDK in some cases."
references = ["smithy-rs#2958", "aws-sdk-rust#873"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "rcoh"

[[aws-sdk-rust]]
message = "Allow `no_credentials` to be used with all S3 operations."
references = ["smithy-rs#2955", "aws-sdk-rust#878"]
+33 −3
Original line number Diff line number Diff line
@@ -23,7 +23,9 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientRestXmlFactory
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -106,6 +108,34 @@ class S3Decorator : ClientCodegenDecorator {
        )
    }

    override fun operationCustomizations(
        codegenContext: ClientCodegenContext,
        operation: OperationShape,
        baseCustomizations: List<OperationCustomization>,
    ): List<OperationCustomization> {
        return baseCustomizations + object : OperationCustomization() {
            override fun section(section: OperationSection): Writable {
                return writable {
                    when (section) {
                        is OperationSection.BeforeParseResponse -> {
                            section.body?.also { body ->
                                rustTemplate(
                                    """
                                    if matches!(#{errors}::body_is_error($body), Ok(true)) {
                                        ${section.forceError} = true;
                                    }
                                    """,
                                    "errors" to RuntimeType.unwrappedXmlErrors(codegenContext.runtimeConfig),
                                )
                            }
                        }
                        else -> {}
                    }
                }
            }
        }
    }

    private fun isInInvalidXmlRootAllowList(shape: Shape): Boolean {
        return shape.isStructureShape && invalidXmlRootAllowList.contains(shape.id)
    }
@@ -115,7 +145,7 @@ class FilterEndpointTests(
    private val testFilter: (EndpointTestCase) -> EndpointTestCase? = { a -> a },
    private val operationInputFilter: (EndpointTestOperationInput) -> EndpointTestOperationInput? = { a -> a },
) {
    fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
    private fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
        val filteredTests = endpointTests.mapNotNull { test -> testFilter(test) }
        return filteredTests.map { test ->
            val operationInputs = test.operationInputs
@@ -123,7 +153,7 @@ class FilterEndpointTests(
        }
    }

    fun transform(model: Model) = ModelTransformer.create().mapTraits(model) { _, trait ->
    fun transform(model: Model): Model = ModelTransformer.create().mapTraits(model) { _, trait ->
        when (trait) {
            is EndpointTestsTrait -> EndpointTestsTrait.builder().testCases(updateEndpointTests(trait.testCases))
                .version(trait.version).build()
@@ -135,7 +165,7 @@ class FilterEndpointTests(

// TODO(P96049742): This model transform may need to change depending on if and how the S3 model is updated.
private class AddOptionalAuth {
    fun transform(model: Model) = ModelTransformer.create().mapShapes(model) { shape ->
    fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
        // Add @optionalAuth to all S3 operations
        if (shape is OperationShape && !shape.hasTrait<OptionalAuthTrait>()) {
            shape.toBuilder()
+40 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_sdk_s3::Client;
use aws_smithy_client::test_connection::infallible_connection_fn;
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_types::region::Region;
use aws_types::SdkConfig;

const ERROR_RESPONSE: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
        <Error>
            <Code>SlowDown</Code>
            <Message>Please reduce your request rate.</Message>
            <RequestId>K2H6N7ZGQT6WHCEG</RequestId>
            <HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
        </Error>
"#;

#[tokio::test]
async fn status_200_errors() {
    let conn = infallible_connection_fn(|_req| http::Response::new(SdkBody::from(ERROR_RESPONSE)));
    let sdk_config = SdkConfig::builder()
        .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
        .region(Region::new("us-west-4"))
        .http_connector(conn)
        .build();
    let client = Client::new(&sdk_config);
    let error = client
        .delete_objects()
        .bucket("bucket")
        .send()
        .await
        .expect_err("should fail");
    assert_eq!(error.into_service_error().code(), Some("SlowDown"));
}
+8 −0
Original line number Diff line number Diff line
@@ -48,6 +48,14 @@ sealed class OperationSection(name: String) : Section(name) {
    data class BeforeParseResponse(
        override val customizations: List<OperationCustomization>,
        val responseName: String,
        /**
         * Name of the `force_error` variable. Set this to true to trigger error parsing.
         */
        val forceError: String,
        /**
         * When set, the name of the response body data field
         */
        val body: String?,
    ) : OperationSection("BeforeParseResponse")

    /**
+8 −4
Original line number Diff line number Diff line
@@ -94,10 +94,12 @@ class ResponseDeserializerGenerator(
        rustTemplate(
            """
            fn deserialize_streaming(&self, response: &mut #{HttpResponse}) -> #{Option}<#{OutputOrError}> {
                ##[allow(unused_mut)]
                let mut force_error = false;
                #{BeforeParseResponse}

                // If this is an error, defer to the non-streaming parser
                if !response.status().is_success() && response.status().as_u16() != $successCode {
                if (!response.status().is_success() && response.status().as_u16() != $successCode) || force_error {
                    return #{None};
                }
                #{Some}(#{type_erase_result}(#{parse_streaming_response}(response)))
@@ -106,7 +108,7 @@ class ResponseDeserializerGenerator(
            *codegenScope,
            "parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, customizations),
            "BeforeParseResponse" to writable {
                writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
                writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", body = null))
            },
        )
    }
@@ -136,8 +138,10 @@ class ResponseDeserializerGenerator(
            let (success, status) = (response.status().is_success(), response.status().as_u16());
            let headers = response.headers();
            let body = response.body().bytes().expect("body loaded");
            ##[allow(unused_mut)]
            let mut force_error = false;
            #{BeforeParseResponse}
            let parse_result = if !success && status != $successCode {
            let parse_result = if !success && status != $successCode || force_error {
                #{parse_error}(status, headers, body)
            } else {
                #{parse_response}(status, headers, body)
@@ -148,7 +152,7 @@ class ResponseDeserializerGenerator(
            "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations),
            "parse_response" to parserGenerator.parseResponseFn(operationShape, customizations),
            "BeforeParseResponse" to writable {
                writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
                writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", "body"))
            },
        )
    }