Unverified Commit dfd53fb5 authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Make Rpc V2 CBOR `awsQuery` compatible (#4186)

## Description
This PR makes Rpc V2 CBOR a compatible protocol for `awsQuery` using
`awsQueryCompatible` trait, as described in the `Important` section in
[this
page](https://smithy.io/2.0/aws/protocols/aws-query-protocol.html#aws-protocols-awsquerycompatible-trait).

Previously, the implementation for `awsQueryCompatible` was tightly
coupled to `awsJson1_0`. This PR makes the implementation a bit more
abstract so that the implementation can support more target protocols
generically.

## Testing
- CI
- Made `AwsQueryCompatibleTest` parameterized tests to support RpcV2Cbor
and verified against modified cloudwatch service model (with
`awsQueryCompatible` trait applied). ~Tests for RpcV2Cbor is commented
out till https://github.com/smithy-lang/smithy-rs/pull/4185 is merged~.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.
- [x] For changes to the AWS SDK, generated SDK code, or SDK runtime
crates, I have created a changelog entry Markdown file in the
`.changelog` directory, specifying "aws-sdk-rust" in the `applies_to`
key.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 8b32dee5
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
---
applies_to:
- aws-sdk-rust
- client
authors:
- ysaito1001
references:
- smithy-rs#4186
breaking: false
new_feature: false
bug_fix: false
---
Make Rpc V2 CBOR a compatible protocol for `awsQuery` using `awsQueryCompatible` trait
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ dependencies {
    implementation(project(":codegen-core"))
    implementation(kotlin("stdlib-jdk8"))
    api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
    implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
+45 −2
Original line number Diff line number Diff line
@@ -16,13 +16,17 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryCompatible
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Ec2QueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ParseErrorMetadataParams
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader
@@ -67,7 +71,23 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) :
    ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
    override fun protocol(codegenContext: ClientCodegenContext): Protocol =
        if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) {
            AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version))
            AwsQueryCompatible(
                codegenContext, AwsJson(codegenContext, version),
                ParseErrorMetadataParams(
                    RuntimeType.smithyJson(codegenContext.runtimeConfig)
                        .resolve("deserialize::error::DeserializeError"),
                    writable {
                        rustTemplate(
                            """
                            #{parse_error_metadata}(response_body, response_headers)?
                            """,
                            "parse_error_metadata" to
                                RuntimeType.jsonErrors(codegenContext.runtimeConfig)
                                    .resolve("parse_error_metadata"),
                        )
                    },
                ),
            )
        } else {
            AwsJson(codegenContext, version)
        }
@@ -122,10 +142,33 @@ class ClientRestXmlFactory(
}

class ClientRpcV2CborFactory : ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
    override fun protocol(codegenContext: ClientCodegenContext): Protocol = RpcV2Cbor(codegenContext)
    override fun protocol(codegenContext: ClientCodegenContext): Protocol =
        if (compatibleWithAwsQuery(codegenContext.serviceShape)) {
            AwsQueryCompatible(
                codegenContext, RpcV2Cbor(codegenContext),
                ParseErrorMetadataParams(
                    RuntimeType.smithyCbor(codegenContext.runtimeConfig)
                        .resolve("decode::DeserializeError"),
                    writable {
                        rustTemplate(
                            """
                            #{parse_error_metadata}(_response_status, response_headers, response_body)?
                            """,
                            "parse_error_metadata" to
                                RuntimeType.cborErrors(codegenContext.runtimeConfig)
                                    .resolve("parse_error_metadata"),
                        )
                    },
                ),
            )
        } else {
            RpcV2Cbor(codegenContext)
        }

    override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
        OperationGenerator(codegenContext, protocol(codegenContext))

    override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT

    private fun compatibleWithAwsQuery(serviceShape: ServiceShape) = serviceShape.hasTrait<AwsQueryCompatibleTrait>()
}
+88 −24
Original line number Diff line number Diff line
@@ -5,26 +5,89 @@

package software.amazon.smithy.rust.codegen.client.smithy.protocols

import org.junit.jupiter.api.Test
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.testModule
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
import software.amazon.smithy.rust.codegen.core.util.letIf
import java.util.stream.Stream

data class AwsQueryCompatibleTestInput(
    val protocolAnnotation: String,
    val payload: Writable,
)

class AwsQueryCompatibleTestInputProvider : ArgumentsProvider {
    private fun jsonStringToBytesArrayString(jsonString: String): String {
        val jsonMapper = ObjectMapper()
        val cborMapper = ObjectMapper(CBORFactory())
        // Parse JSON string to a generic type.
        val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
        // Convert the parsed data to CBOR.
        val bytes = cborMapper.writeValueAsBytes(jsonData)
        return bytes
            .joinToString(
                prefix = "&[",
                postfix = "]",
                transform = { "0x${it.toUByte().toString(16).padStart(2, '0')}u8" },
            )
    }

    override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> =
        listOf(
            AwsQueryCompatibleTestInput(
                "@awsJson1_0",
                writable {
                    rust(
                        """
                        r##"{
                            "__type": "com.amazonaws.sqs##QueueDoesNotExist",
                            "message": "Some user-visible message"
                        }"##
                        """,
                    )
                },
            ),
            AwsQueryCompatibleTestInput(
                "@rpcv2Cbor",
                writable {
                    val bytesArray =
                        jsonStringToBytesArrayString(
                            """
                            {
                                "__type": "com.amazonaws.sqs#QueueDoesNotExist",
                                "message": "Some user-visible message"
                            }
                            """,
                        )
                    rust("#T::from_static($bytesArray)", RuntimeType.Bytes)
                },
            ),
        ).map { Arguments.of(it) }.stream()
}

class AwsQueryCompatibleTest {
    companion object {
        const val prologue = """
            namespace test
            use smithy.protocols#rpcv2Cbor
            use aws.protocols#awsJson1_0
            use aws.protocols#awsQueryCompatible
            use aws.protocols#awsQueryError
        """

        const val awsjson10Trait = "@awsJson1_0"
        const val awsQueryCompatibleTrait = "@awsQueryCompatible"

        fun testService(withAwsQueryError: Boolean = true) =
@@ -63,10 +126,13 @@ class AwsQueryCompatibleTest {
            }
    }

    @Test
    fun `aws-query-compatible json with aws query error should allow for retrieving error code and type from custom header`() {
    @ParameterizedTest
    @ArgumentsSource(AwsQueryCompatibleTestInputProvider::class)
    fun `aws-query-compatible json with aws query error should allow for retrieving error code and type from custom header`(
        testInput: AwsQueryCompatibleTestInput,
    ) {
        val model =
            (prologue + awsQueryCompatibleTrait + awsjson10Trait + testService()).asSmithyModel(
            (prologue + awsQueryCompatibleTrait + testInput.protocolAnnotation + testService()).asSmithyModel(
                smithyVersion = "2",
            )
        clientIntegrationTest(model) { context, rustCrate ->
@@ -82,12 +148,7 @@ class AwsQueryCompatibleTest {
                                )
                                .status(400)
                                .body(
                                    #{SdkBody}::from(
                                        r##"{
                                            "__type": "com.amazonaws.sqs##QueueDoesNotExist",
                                            "message": "Some user-visible message"
                                        }"##
                                    )
                                    #{SdkBody}::from(#{payload:W})
                                )
                                .unwrap()
                        };
@@ -105,6 +166,7 @@ class AwsQueryCompatibleTest {
                        assert_eq!(#{Some}("Sender"), error.meta().extra("type"));
                        """,
                        *RuntimeType.preludeScope,
                        "payload" to testInput.payload,
                        "SdkBody" to RuntimeType.sdkBody(context.runtimeConfig),
                        "infallible_client_fn" to
                            CargoDependency.smithyHttpClientTestUtil(context.runtimeConfig)
@@ -116,10 +178,13 @@ class AwsQueryCompatibleTest {
        }
    }

    @Test
    fun `aws-query-compatible json without aws query error should allow for retrieving error code from payload`() {
    @ParameterizedTest
    @ArgumentsSource(AwsQueryCompatibleTestInputProvider::class)
    fun `aws-query-compatible json without aws query error should allow for retrieving error code from payload`(
        testInput: AwsQueryCompatibleTestInput,
    ) {
        val model =
            (prologue + awsQueryCompatibleTrait + awsjson10Trait + testService(withAwsQueryError = false)).asSmithyModel(
            (prologue + awsQueryCompatibleTrait + testInput.protocolAnnotation + testService(withAwsQueryError = false)).asSmithyModel(
                smithyVersion = "2",
            )
        clientIntegrationTest(model) { context, rustCrate ->
@@ -131,12 +196,7 @@ class AwsQueryCompatibleTest {
                            #{http_1x}::Response::builder()
                                .status(400)
                                .body(
                                    #{SdkBody}::from(
                                        r##"{
                                            "__type": "com.amazonaws.sqs##QueueDoesNotExist",
                                            "message": "Some user-visible message"
                                        }"##,
                                    )
                                    #{SdkBody}::from(#{payload:W})
                                )
                                .unwrap()
                        };
@@ -152,6 +212,7 @@ class AwsQueryCompatibleTest {
                        """,
                        *RuntimeType.preludeScope,
                        "SdkBody" to RuntimeType.sdkBody(context.runtimeConfig),
                        "payload" to testInput.payload,
                        "infallible_client_fn" to
                            CargoDependency.smithyHttpClientTestUtil(context.runtimeConfig)
                                .toType().resolve("test_util::infallible_client_fn"),
@@ -162,10 +223,13 @@ class AwsQueryCompatibleTest {
        }
    }

    @Test
    fun `request header should include x-amzn-query-mode when the service has the awsQueryCompatible trait`() {
    @ParameterizedTest
    @ArgumentsSource(AwsQueryCompatibleTestInputProvider::class)
    fun `request header should include x-amzn-query-mode when the service has the awsQueryCompatible trait`(
        testInput: AwsQueryCompatibleTestInput,
    ) {
        val model =
            (prologue + awsQueryCompatibleTrait + awsjson10Trait + testService()).asSmithyModel(
            (prologue + awsQueryCompatibleTrait + testInput.protocolAnnotation + testService()).asSmithyModel(
                smithyVersion = "2",
            )
        clientIntegrationTest(model) { context, rustCrate ->
+34 −31
Original line number Diff line number Diff line
@@ -10,51 +10,56 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator

class AwsQueryCompatibleHttpBindingResolver(
    private val awsQueryBindingResolver: AwsQueryBindingResolver,
    private val awsJsonHttpBindingResolver: AwsJsonHttpBindingResolver,
    private val targetProtocolHttpBinding: HttpBindingResolver,
) : HttpBindingResolver {
    override fun httpTrait(operationShape: OperationShape): HttpTrait =
        awsJsonHttpBindingResolver.httpTrait(operationShape)
        targetProtocolHttpBinding.httpTrait(operationShape)

    override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
        awsJsonHttpBindingResolver.requestBindings(operationShape)
        targetProtocolHttpBinding.requestBindings(operationShape)

    override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
        awsJsonHttpBindingResolver.responseBindings(operationShape)
        targetProtocolHttpBinding.responseBindings(operationShape)

    override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> =
        awsJsonHttpBindingResolver.errorResponseBindings(errorShape)
        targetProtocolHttpBinding.errorResponseBindings(errorShape)

    override fun errorCode(errorShape: ToShapeId): String = awsQueryBindingResolver.errorCode(errorShape)

    override fun requestContentType(operationShape: OperationShape): String =
        awsJsonHttpBindingResolver.requestContentType(operationShape)
    override fun requestContentType(operationShape: OperationShape): String? =
        targetProtocolHttpBinding.requestContentType(operationShape)

    override fun responseContentType(operationShape: OperationShape): String =
        awsJsonHttpBindingResolver.requestContentType(operationShape)
    override fun responseContentType(operationShape: OperationShape): String? =
        targetProtocolHttpBinding.requestContentType(operationShape)

    override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
        awsJsonHttpBindingResolver.eventStreamMessageContentType(memberShape)
        targetProtocolHttpBinding.eventStreamMessageContentType(memberShape)

    override fun handlesEventStreamInitialRequest(shape: Shape): Boolean =
        awsJsonHttpBindingResolver.handlesEventStreamInitialRequest(shape)
        targetProtocolHttpBinding.handlesEventStreamInitialRequest(shape)

    override fun handlesEventStreamInitialResponse(shape: Shape): Boolean =
        awsJsonHttpBindingResolver.handlesEventStreamInitialResponse(shape)
        targetProtocolHttpBinding.handlesEventStreamInitialResponse(shape)
}

data class ParseErrorMetadataParams(
    val deserializeErrorType: RuntimeType,
    val innerParseErrorMetadata: Writable,
)

class AwsQueryCompatible(
    val codegenContext: CodegenContext,
    private val awsJson: AwsJson,
    private val targetProtocol: Protocol,
    private val params: ParseErrorMetadataParams,
) : Protocol {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val errorScope =
@@ -62,33 +67,29 @@ class AwsQueryCompatible(
            "Bytes" to RuntimeType.Bytes,
            "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
            "Headers" to RuntimeType.headers(runtimeConfig),
            "JsonError" to
                CargoDependency.smithyJson(runtimeConfig).toType()
                    .resolve("deserialize::error::DeserializeError"),
            "aws_query_compatible_errors" to RuntimeType.awsQueryCompatibleErrors(runtimeConfig),
            "json_errors" to RuntimeType.jsonErrors(runtimeConfig),
            *RuntimeType.preludeScope,
        )

    override val httpBindingResolver: HttpBindingResolver =
        AwsQueryCompatibleHttpBindingResolver(
            AwsQueryBindingResolver(codegenContext.model),
            AwsJsonHttpBindingResolver(codegenContext.model, awsJson.version, codegenContext.target == CodegenTarget.SERVER),
            targetProtocol.httpBindingResolver,
        )

    override val defaultTimestampFormat = awsJson.defaultTimestampFormat
    override val defaultTimestampFormat = targetProtocol.defaultTimestampFormat

    override fun structuredDataParser(): StructuredDataParserGenerator = awsJson.structuredDataParser()
    override fun structuredDataParser(): StructuredDataParserGenerator = targetProtocol.structuredDataParser()

    override fun structuredDataSerializer(): StructuredDataSerializerGenerator = awsJson.structuredDataSerializer()
    override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
        targetProtocol.structuredDataSerializer()

    override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
        ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName ->
            rustTemplate(
                """
                pub fn $fnName(_response_status: u16, response_headers: &#{Headers}, response_body: &[u8]) -> #{Result}<#{ErrorMetadataBuilder}, #{JsonError}> {
                    let mut builder =
                        #{json_errors}::parse_error_metadata(response_body, response_headers)?;
                pub fn $fnName(_response_status: u16, response_headers: &#{Headers}, response_body: &[u8]) -> #{Result}<#{ErrorMetadataBuilder}, #{DeserializeError}> {
                    let mut builder = #{parse_error_metadata};
                    if let Some((error_code, error_type)) =
                        #{aws_query_compatible_errors}::parse_aws_query_compatible_error(response_headers)
                    {
@@ -99,15 +100,17 @@ class AwsQueryCompatible(
                }
                """,
                *errorScope,
                "DeserializeError" to params.deserializeErrorType,
                "parse_error_metadata" to params.innerParseErrorMetadata,
            )
        }

    override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
        awsJson.parseEventStreamErrorMetadata(operationShape)
        targetProtocol.parseEventStreamErrorMetadata(operationShape)

    override fun additionalRequestHeaders(operationShape: OperationShape): List<Pair<String, String>> =
        targetProtocol.additionalRequestHeaders(operationShape) +
            listOf(
            "x-amz-target" to "${codegenContext.serviceShape.id.name}.${operationShape.id.name}",
                "x-amzn-query-mode" to "true",
            )
}