Unverified Commit 36b50b3d authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Add client-support for RPC v2 CBOR (#3767)

## Motivation and Context
Follow-up on https://github.com/smithy-lang/smithy-rs/pull/2544 to add
client-side support for the protocol

## Description
The client implementation mainly focuses on a sub-section
[Requests](https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#requests)
in the spec. To that end, this PR addresses `TODO` for the client to
fill in the blanks and includes additional adjustments/refactoring to
pass client protocol tests.

## Testing
- Existing tests in CI
- Upstream protocol test `rpcv2Cbor`
- Our handwritten protocol test `rpcv2Cbor-extras.smithy`

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 50148e69
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -22,3 +22,9 @@ message = "Fix incorrect redaction of `@sensitive` types in maps and lists."
references = ["smithy-rs#3765",  "smithy-rs#3757"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "landonxjames"

[[smithy-rs]]
message = "Fix client error correction to properly parse structure members that target a `Union` containing that structure recursively."
references = ["smithy-rs#3767"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "ysaito1001"
+7 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ val workingDirUnderBuildDir = "smithyprojections/codegen-client-test/"
dependencies {
    implementation(project(":codegen-client"))
    implementation("software.amazon.smithy:smithy-aws-protocol-tests:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-tests:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
}
@@ -72,6 +73,12 @@ val allCodegenTests = listOf(
    ClientTest("aws.protocoltests.restxml#RestXml", "rest_xml", addMessageToErrors = false),
    ClientTest("aws.protocoltests.query#AwsQuery", "aws_query", addMessageToErrors = false),
    ClientTest("aws.protocoltests.ec2#AwsEc2", "ec2_query", addMessageToErrors = false),
    ClientTest("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol", "rpcv2Cbor"),
    ClientTest(
        "smithy.protocoltests.rpcv2Cbor#RpcV2CborService",
        "rpcv2Cbor_extras",
        dependsOn = listOf("rpcv2Cbor-extras.smithy")
    ),
    ClientTest(
        "aws.protocoltests.restxml.xmlns#RestXmlWithNamespace",
        "rest_xml_namespace",
+12 −1
Original line number Diff line number Diff line
@@ -87,7 +87,18 @@ private fun ClientCodegenContext.errorCorrectedDefault(member: MemberShape): Wri

            target is TimestampShape -> instantiator.instantiate(target, Node.from(0)).some()(this)
            target is BlobShape -> instantiator.instantiate(target, Node.from("")).some()(this)
            target is UnionShape -> rust("Some(#T::Unknown)", targetSymbol)
            target is UnionShape ->
                rustTemplate(
                    "Some(#{unknown})", *preludeScope,
                    "unknown" to
                        writable {
                            if (memberSymbol.isRustBoxed()) {
                                rust("Box::new(#T::Unknown)", targetSymbol)
                            } else {
                                rust("#T::Unknown", targetSymbol)
                            }
                        },
                )
        }
    }
}
+9 −5
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.Proto
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.REST_JSON
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.RPC_V2_CBOR
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
@@ -78,6 +79,8 @@ class ClientProtocolTestGenerator(
                FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"),
                FailingTest.RequestTest(REST_JSON, "RestJsonClientPopulatesDefaultValuesInInput"),
                FailingTest.RequestTest(REST_JSON, "RestJsonClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
                FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientPopulatesDefaultValuesInInput"),
                FailingTest.RequestTest(RPC_V2_CBOR, "RpcV2CborClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
            )

        private val BrokenTests:
@@ -268,6 +271,7 @@ class ClientProtocolTestGenerator(
            """,
            RT.sdkBody(runtimeConfig = rc),
        )
        val mediaType = testCase.bodyMediaType.orNull()
        rustTemplate(
            """
            use #{DeserializeResponse};
@@ -280,19 +284,19 @@ class ClientProtocolTestGenerator(
            let parsed = de.deserialize_streaming(&mut http_response);
            let parsed = parsed.unwrap_or_else(|| {
                let http_response = http_response.map(|body| {
                    #{SdkBody}::from(#{copy_from_slice}(body.bytes().unwrap()))
                    #{SdkBody}::from(#{copy_from_slice}(&#{decode_body_data}(body.bytes().unwrap(), #{MediaType}::from(${(mediaType ?: "unknown").dq()}))))
                });
                de.deserialize_nonstreaming(&http_response)
            });
            """,
            "copy_from_slice" to RT.Bytes.resolve("copy_from_slice"),
            "SharedResponseDeserializer" to
                RT.smithyRuntimeApiClient(rc)
                    .resolve("client::ser_de::SharedResponseDeserializer"),
            "Operation" to codegenContext.symbolProvider.toSymbol(operationShape),
            "decode_body_data" to RT.protocolTest(rc, "decode_body_data"),
            "DeserializeResponse" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::DeserializeResponse"),
            "MediaType" to RT.protocolTest(rc, "MediaType"),
            "Operation" to codegenContext.symbolProvider.toSymbol(operationShape),
            "RuntimePlugin" to RT.runtimePlugin(rc),
            "SdkBody" to RT.sdkBody(rc),
            "SharedResponseDeserializer" to RT.smithyRuntimeApiClient(rc).resolve("client::ser_de::SharedResponseDeserializer"),
        )
        if (expectedShape.hasTrait<ErrorTrait>()) {
            val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape)
+2 −5
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.findStreamingMember
@@ -125,10 +124,8 @@ class RequestSerializerGenerator(
        )
    }

    private fun needsContentLength(operationShape: OperationShape): Boolean {
        return protocol.httpBindingResolver.requestBindings(operationShape)
            .any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD }
    }
    private fun needsContentLength(operationShape: OperationShape): Boolean =
        protocol.needsRequestContentLength(operationShape)

    private fun createHttpRequest(operationShape: OperationShape): Writable =
        writable {
Loading