Unverified Commit fd8b4f6c authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Fix protocol selection behavior in `ClientProtocolLoader` (#4165)

## Description
The bugs include
- The default SDK-supported protocols (`DefaultProtocols`) were listed
in an incorrect priority order.
- Protocol resolution logic incorrectly iterated over service-applied
protocols, which is returned by `getProtocols()` whose result may not
reflect the intended priority.

This PR addresses these issues.

## Testing
- Existing CI
- `ClientProtocolLoaderTest.kt`

## Checklist
- [x] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.
- [x] For changes to the AWS SDK, generated SDK code, or SDK runtime
crates, I have created a changelog entry Markdown file in the
`.changelog` directory, specifying "aws-sdk-rust" in the `applies_to`
key.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent dc480dae
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
---
applies_to:
- client
- aws-sdk-rust
authors:
- ysaito1001
references:
- smithy-rs#4165
breaking: false
new_feature: false
bug_fix: true
---
Fix default supported protocols incorrectly ordered in `ClientProtocolLoader`.
+3 −3
Original line number Diff line number Diff line
@@ -37,13 +37,13 @@ class ClientProtocolLoader(supportedProtocols: ProtocolMap<OperationGenerator, C
    companion object {
        val DefaultProtocols =
            mapOf(
                Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
                AwsJson1_0Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json10),
                AwsJson1_1Trait.ID to ClientAwsJsonFactory(AwsJsonVersion.Json11),
                AwsQueryTrait.ID to ClientAwsQueryFactory(),
                Ec2QueryTrait.ID to ClientEc2QueryFactory(),
                RestJson1Trait.ID to ClientRestJsonFactory(),
                RestXmlTrait.ID to ClientRestXmlFactory(),
                Rpcv2CborTrait.ID to ClientRpcV2CborFactory(),
                AwsQueryTrait.ID to ClientAwsQueryFactory(),
                Ec2QueryTrait.ID to ClientEc2QueryFactory(),
            )
        val Default = ClientProtocolLoader(DefaultProtocols)
    }
+123 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.client.smithy.protocols

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.ArgumentsProvider
import org.junit.jupiter.params.provider.ArgumentsSource
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader.Companion.DefaultProtocols
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import java.util.stream.Stream

data class TestCase(
    val supportedProtocols: ProtocolMap<OperationGenerator, ClientCodegenContext>,
    val model: Model,
    val resolvedProtocol: String?,
)

class ClientProtocolLoaderTest {
    @Test
    fun `test priority order of default supported protocols`() {
        val expectedOrder =
            listOf(
                Rpcv2CborTrait.ID,
                AwsJson1_0Trait.ID,
                AwsJson1_1Trait.ID,
                RestJson1Trait.ID,
                RestXmlTrait.ID,
                AwsQueryTrait.ID,
                Ec2QueryTrait.ID,
            )
        assertEquals(expectedOrder, DefaultProtocols.keys.toList())
    }

    // Although the test function name appears generic, its purpose is to verify whether
    // the RPCv2Cbor protocol is selected based on specific contexts.
    @ParameterizedTest
    @ArgumentsSource(ProtocolSelectionTestCaseProvider::class)
    fun `should resolve expected protocol`(testCase: TestCase) {
        val protocolLoader = ClientProtocolLoader(testCase.supportedProtocols)
        val serviceShape = testCase.model.expectShape(ShapeId.from("test#TestService"), ServiceShape::class.java)
        if (testCase.resolvedProtocol.isNullOrEmpty()) {
            assertThrows<CodegenException> {
                protocolLoader.protocolFor(testCase.model, serviceShape)
            }
        } else {
            val actual = protocolLoader.protocolFor(testCase.model, serviceShape).first.name
            assertEquals(testCase.resolvedProtocol, actual)
        }
    }
}

class ProtocolSelectionTestCaseProvider : ArgumentsProvider {
    override fun provideArguments(p0: ExtensionContext?): Stream<out Arguments> {
        val protocolsWithoutRpcv2Cbor = LinkedHashMap(DefaultProtocols)
        protocolsWithoutRpcv2Cbor.remove(Rpcv2CborTrait.ID)

        return arrayOf(
            TestCase(DefaultProtocols, model(listOf("rpcv2Cbor", "awsJson1_0")), "rpcv2Cbor"),
            TestCase(DefaultProtocols, model(listOf("rpcv2Cbor")), "rpcv2Cbor"),
            TestCase(DefaultProtocols, model(listOf("rpcv2Cbor", "awsJson1_0", "awsQuery")), "rpcv2Cbor"),
            TestCase(DefaultProtocols, model(listOf("awsJson1_0", "awsQuery")), "awsJson1_0"),
            TestCase(DefaultProtocols, model(listOf("awsQuery")), "awsQuery"),
            TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor", "awsJson1_0")), "awsJson1_0"),
            TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor")), null),
            TestCase(protocolsWithoutRpcv2Cbor, model(listOf("rpcv2Cbor", "awsJson1_0", "awsQuery")), "awsJson1_0"),
            TestCase(protocolsWithoutRpcv2Cbor, model(listOf("awsJson1_0", "awsQuery")), "awsJson1_0"),
            TestCase(protocolsWithoutRpcv2Cbor, model(listOf("awsQuery")), "awsQuery"),
        ).map { Arguments.of(it) }.stream()
    }

    private fun model(protocols: List<String>) =
        (
            """
            namespace test
            """ + renderProtocols(protocols) +
                """
                @xmlNamespace(uri: "http://test.com") // required for @awsQuery
                service TestService {
                    version: "1.0.0"
                }
                """
        ).asSmithyModel(smithyVersion = "2.0")
}

private fun renderProtocols(protocols: List<String>): String {
    val (rpcProtocols, awsProtocols) = protocols.partition { it == "rpcv2Cbor" }

    val uses =
        buildList {
            rpcProtocols.forEach { add("use smithy.protocols#$it") }
            awsProtocols.forEach { add("use aws.protocols#$it") }
        }.joinToString("\n")

    val annotations = protocols.joinToString("\n") { "@$it" }

    return """
        $uses

        $annotations
    """
}
+5 −3
Original line number Diff line number Diff line
@@ -18,11 +18,13 @@ open class ProtocolLoader<T, C : CodegenContext>(private val supportedProtocols:
        model: Model,
        serviceShape: ServiceShape,
    ): Pair<ShapeId, ProtocolGeneratorFactory<T, C>> {
        val protocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
        val serviceProtocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
        val matchingProtocols =
            protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } }
            supportedProtocols.mapNotNull { (protocolId, factory) ->
                serviceProtocols[protocolId]?.let { protocolId to factory }
            }
        if (matchingProtocols.isEmpty()) {
            throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}")
            throw CodegenException("No matching protocol — service offers: ${serviceProtocols.keys}. We offer: ${supportedProtocols.keys}")
        }
        return matchingProtocols.first()
    }