Unverified Commit 8075c771 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Customize S3's GetBucketLocation call to correctly parse the response (#516)

* Customize S3's GetBucketLocation call to correctly parse the response

* Move the regression test into s3-tests.smithy
parent c75a58f7
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -6,8 +6,12 @@
package software.amazon.smithy.rustsdk.customize.s3

import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
@@ -24,6 +28,7 @@ import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXmlFactory
import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rustsdk.AwsRuntimeType

/**
@@ -32,6 +37,7 @@ import software.amazon.smithy.rustsdk.AwsRuntimeType
class S3Decorator : RustCodegenDecorator {
    override val name: String = "S3ExtendedError"
    override val order: Byte = 0

    private fun applies(serviceId: ShapeId) =
        serviceId == ShapeId.from("com.amazonaws.s3#AmazonS3")

@@ -53,6 +59,20 @@ class S3Decorator : RustCodegenDecorator {
            it + S3PubUse()
        }
    }

    override fun transformModel(service: ServiceShape, model: Model): Model {
        return model.letIf(applies(service.id)) {
            ModelTransformer.create().mapShapes(model) { shape ->
                // Apply the S3UnwrappedXmlOutput customization to GetBucketLocation (more
                // details on the S3UnwrappedXmlOutputTrait)
                if (shape is StructureShape && shape.id == ShapeId.from("com.amazonaws.s3#GetBucketLocationOutput")) {
                    shape.toBuilder().addTrait(S3UnwrappedXmlOutputTrait()).build()
                } else {
                    shape
                }
            }
        }
    }
}

class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) {
+13 −0
Original line number Diff line number Diff line
@@ -21,3 +21,16 @@ apply NotFound @httpResponseTests([
        }
    }
])

apply GetBucketLocation @httpResponseTests([
    {
        id: "GetBucketLocation",
        documentation: "This test case validates https://github.com/awslabs/aws-sdk-rust/issues/116",
        code: 200,
        body: "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<LocationConstraint xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">us-west-2</LocationConstraint>",
        params: {
            "LocationConstraint": "us-west-2"
        },
        protocol: "aws.protocols#restXml"
    }
])
+0 −4
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */
+32 −2
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -191,9 +192,13 @@ class XmlBindingTraitParserGenerator(
                    *codegenScope
                )
                val context = OperationWrapperContext(operationShape, shapeName, xmlError)
                if (outputShape.hasTrait<S3UnwrappedXmlOutputTrait>()) {
                    unwrappedResponseParser("builder", "decoder", "start_el", outputShape.members())
                } else {
                    writeOperationWrapper(context) { tagName ->
                        parseStructureInner(members, builder = "builder", Ctx(tag = tagName, accum = null))
                    }
                }
                rust("Ok(builder)")
            }
        }
@@ -233,6 +238,31 @@ class XmlBindingTraitParserGenerator(
        TODO("Document shapes are not supported by rest XML")
    }

    private fun RustWriter.unwrappedResponseParser(
        builder: String,
        decoder: String,
        element: String,
        members: Collection<MemberShape>
    ) {
        check(members.size == 1) {
            "The S3UnwrappedXmlOutputTrait is only allowed on structs with exactly one member"
        }
        val member = members.first()
        rustBlock("match $element") {
            case(member) {
                val temp = safeName()
                withBlock("let $temp =", ";") {
                    parseMember(
                        member,
                        Ctx(tag = decoder, accum = "$builder.${symbolProvider.toMemberName(member)}.take()")
                    )
                }
                rust("$builder = $builder.${member.setterName()}($temp);")
            }
            rustTemplate("_ => return Err(#{XmlError}::custom(\"expected ${member.xmlName()} tag\"))", *codegenScope)
        }
    }

    /**
     * Update a structure builder based on the [members], specifying where to find each member (document vs. attributes)
     */
+34 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.traits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait

/**
 * S3's GetBucketLocation response shape can't be represented with Smithy's restXml protocol
 * without customization. We add this trait to the S3 model at codegen time so that a different
 * code path is taken in the XML deserialization codegen to generate code that parses the S3
 * response shape correctly.
 *
 * From what the S3 model states, the generated parser would expect:
 * ```
 * <LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
 *     <LocationConstraint>us-west-2</LocationConstraint>
 * </LocationConstraint>
 * ```
 *
 * But S3 actually responds with:
 * ```
 * <LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">us-west-2</LocationConstraint>
 * ```
 */
class S3UnwrappedXmlOutputTrait : AnnotationTrait(ID, Node.objectNode()) {
    companion object {
        val ID = ShapeId.from("smithy.api.internal#s3UnwrappedXmlOutputTrait")
    }
}