Unverified Commit a884a669 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Enable event streams for all services with REST protocols (#3451)

This PR enables event stream operations for the AWS `@restJson1` and
`@restXml` Smithy protocols. This will add the missing operations
reported in https://github.com/awslabs/aws-sdk-rust/issues/1075 and
https://github.com/awslabs/aws-sdk-rust/issues/1080.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 7815dc4b
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -34,3 +34,9 @@ message = "[`SdkBody`](https://docs.rs/aws-smithy-types/latest/aws_smithy_types/
references = ["smithy-rs#3365", "aws-sdk-rust#1046"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "all" }
authors = ["cayman-amzn", "rcoh"]

[[aws-sdk-rust]]
message = "Add support for Lambda's `InvokeWithResponseStreaming` and Bedrock Agent Runtime's `InvokeAgent` operations."
references = ["aws-sdk-rust#1075", "aws-sdk-rust#1080", "smithy-rs#3451"]
meta = { "breaking" = false, "bug" = false, "tada" = true }
author = "jdisanti"
+5 −5
Original line number Diff line number Diff line
@@ -5,8 +5,8 @@

aws.services=

# List of services to generate Event Stream operations for:
aws.services.eventstream.allowlist=\
    aws-sdk-transcribestreaming,\
    aws-sdk-s3,\
    aws-sdk-bedrockruntime
# List of service (comma-delimited SDK crate names, e.g. aws-sdk-kinesis) that have
# non-REST protocols that should not have their event stream operations removed by the
# RemoveEventStreamOperations transformer. All REST protocol services retain their
# event stream operations.
aws.services.eventstream.allowlist=
+9 −0
Original line number Diff line number Diff line
@@ -5,12 +5,15 @@

package software.amazon.smithy.rust.codegen.client.smithy.transformers

import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.core.util.findStreamingMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
import java.util.logging.Logger

@@ -24,6 +27,12 @@ object RemoveEventStreamOperations {
        model: Model,
        settings: ClientRustSettings,
    ): Model {
        // Only allow event stream operations in AWS REST protocols for now.
        val service = settings.getService(model)
        if (service.hasTrait<RestJson1Trait>() || service.hasTrait<RestXmlTrait>()) {
            return model
        }

        // If Event Stream is allowed in build config, then don't remove the operations
        val allowList = settings.codegenConfig.eventStreamAllowList
        if (allowList.isEmpty() || allowList.contains(settings.moduleName)) {
+82 −30
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.transformers
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenConfig
@@ -16,9 +17,28 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import java.util.Optional

internal class RemoveEventStreamOperationsTest {
    private val model =
        """
    private fun model(
        protocol: String,
        rest: Boolean,
    ): Model {
        val httpPayload =
            if (rest) {
                "@httpPayload"
            } else {
                ""
            }
        return """
            namespace test

            use aws.protocols#awsJson1_0
            use aws.protocols#restJson1
            use aws.protocols#restXml

            @$protocol
            service TestService {
                operations: [EventStream, BlobStream],
            }

            operation EventStream {
                input: StreamingInput,
            }
@@ -27,14 +47,18 @@ internal class RemoveEventStreamOperationsTest {
                input: BlobInput
            }

            @input
            structure BlobInput {
                $httpPayload
                blob: StreamingBlob
            }

            @streaming
            blob StreamingBlob

            @input
            structure StreamingInput {
                $httpPayload
                payload: Event
            }

@@ -45,13 +69,15 @@ internal class RemoveEventStreamOperationsTest {

            structure Foo {}
        """.asSmithyModel()
    }

    @Test
    fun `remove event stream ops from services that are not in the allow list`() {
        val transformed =
            RemoveEventStreamOperations.transform(
                model,
                model(protocol = "awsJson1_0", rest = false),
                testClientRustSettings(
                    service = ShapeId.from("test#TestService"),
                    codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("not-test-module")),
                ),
            )
@@ -63,12 +89,38 @@ internal class RemoveEventStreamOperationsTest {
    fun `keep event stream ops from services that are in the allow list`() {
        val transformed =
            RemoveEventStreamOperations.transform(
                model,
                model(protocol = "awsJson1_0", rest = false),
                testClientRustSettings(
                    service = ShapeId.from("test#TestService"),
                    codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("test-module")),
                ),
            )
        transformed.expectShape(ShapeId.from("test#BlobStream"))
        transformed.getShape(ShapeId.from("test#EventStream")) shouldNotBe Optional.empty<Shape>()
    }

    @Test
    fun `keep event stream ops for rest services`() {
        var transformed =
            RemoveEventStreamOperations.transform(
                model(protocol = "restJson1", rest = true),
                testClientRustSettings(
                    service = ShapeId.from("test#TestService"),
                    codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf()),
                ),
            )
        transformed.expectShape(ShapeId.from("test#BlobStream"))
        transformed.getShape(ShapeId.from("test#EventStream")) shouldNotBe Optional.empty<Shape>()

        transformed =
            RemoveEventStreamOperations.transform(
                model(protocol = "restXml", rest = true),
                testClientRustSettings(
                    service = ShapeId.from("test#TestService"),
                    codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf()),
                ),
            )
        transformed.expectShape(ShapeId.from("test#BlobStream"))
        transformed.getShape(ShapeId.from("test#EventStream")) shouldNotBe Optional.empty<Shape>()
    }
}