Unverified Commit 99c5a575 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Enable the invoke_model_with_response_stream API (#3031)

This API is a streaming API—I tested it manually, an E2E test an example
are forthcoming. I also removed the unused `config` service from the
smoketest and added bedrock-runtime.

## Motivation and Context
This API is suppressed by default because it's an event stream API

## Description


## Testing
https://gist.github.com/rcoh/ce61429303a56fefd65ba05bd36d28fc works



----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent ce5f0aae
Loading
Loading
Loading
Loading
+1166 −0

File added.

Preview size limit exceeded, changes collapsed.

+2 −1
Original line number Diff line number Diff line
@@ -8,4 +8,5 @@ aws.services=
# List of services to generate Event Stream operations for:
aws.services.eventstream.allowlist=\
    aws-sdk-transcribestreaming,\
    aws-sdk-s3
    aws-sdk-s3,\
    aws-sdk-bedrockruntime
+60 −17
Original line number Diff line number Diff line
@@ -78,6 +78,7 @@ class HttpBoundProtocolPayloadGenerator(
        val (shape, payloadMemberName) = when (httpMessageType) {
            HttpMessageType.RESPONSE -> operationShape.outputShape(model) to
                httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

            HttpMessageType.REQUEST -> operationShape.inputShape(model) to
                httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        }
@@ -97,6 +98,7 @@ class HttpBoundProtocolPayloadGenerator(
                is DocumentShape, is StructureShape, is UnionShape -> ProtocolPayloadGenerator.PayloadMetadata(
                    takesOwnership = false,
                )

                is StringShape, is BlobShape -> ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = true)
                else -> UNREACHABLE("Unexpected payload target type: $type")
            }
@@ -110,8 +112,19 @@ class HttpBoundProtocolPayloadGenerator(
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        when (httpMessageType) {
            HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape, additionalPayloadContext)
            HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape, additionalPayloadContext)
            HttpMessageType.RESPONSE -> generateResponsePayload(
                writer,
                shapeName,
                operationShape,
                additionalPayloadContext,
            )

            HttpMessageType.REQUEST -> generateRequestPayload(
                writer,
                shapeName,
                operationShape,
                additionalPayloadContext,
            )
        }
    }

@@ -119,13 +132,20 @@ class HttpBoundProtocolPayloadGenerator(
        writer: RustWriter, shapeName: String, operationShape: OperationShape,
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        val payloadMemberName =
            httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

        if (payloadMemberName == null) {
            val serializerGenerator = protocol.structuredDataSerializer()
            generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape))
        } else {
            generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext)
            generatePayloadMemberSerializer(
                writer,
                shapeName,
                operationShape,
                payloadMemberName,
                additionalPayloadContext,
            )
        }
    }

@@ -133,13 +153,24 @@ class HttpBoundProtocolPayloadGenerator(
        writer: RustWriter, shapeName: String, operationShape: OperationShape,
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        val payloadMemberName =
            httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

        if (payloadMemberName == null) {
            val serializerGenerator = protocol.structuredDataSerializer()
            generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape))
            generateStructureSerializer(
                writer,
                shapeName,
                serializerGenerator.operationOutputSerializer(operationShape),
            )
        } else {
            generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext)
            generatePayloadMemberSerializer(
                writer,
                shapeName,
                operationShape,
                payloadMemberName,
                additionalPayloadContext,
            )
        }
    }

@@ -152,16 +183,24 @@ class HttpBoundProtocolPayloadGenerator(
    ) {
        val serializerGenerator = protocol.structuredDataSerializer()

        if (operationShape.isEventStream(model)) {
        if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
            val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
                writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName, additionalPayloadContext)
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                shapeName,
                additionalPayloadContext,
            )
        } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
            val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
                writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output", additionalPayloadContext)
            } else {
                throw CodegenException("Payload serializer for event streams with an invalid configuration")
            }
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                "output",
                additionalPayloadContext,
            )
        } else {
            val bodyMetadata = payloadMetadata(operationShape)
            val payloadMember = when (httpMessageType) {
@@ -268,6 +307,7 @@ class HttpBoundProtocolPayloadGenerator(
                                Vec::new()
                                """,
                            )

                            is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape))
                            is UnionShape -> rust("#T()", serializerGenerator.unsetUnion(targetShape))
                            else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported")
@@ -310,6 +350,7 @@ class HttpBoundProtocolPayloadGenerator(
                    rust("$payloadName.into_inner()")
                }
            }

            is StructureShape, is UnionShape -> {
                check(
                    !((targetShape as? UnionShape)?.isEventStream() ?: false),
@@ -320,12 +361,14 @@ class HttpBoundProtocolPayloadGenerator(
                    serializer.payloadSerializer(member),
                )
            }

            is DocumentShape -> {
                rust(
                    "#T($payloadName)",
                    serializer.documentSerializer(),
                )
            }

            else -> PANIC("Unexpected payload target type: $targetShape")
        }
    }