Unverified Commit 562d1787 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

0.56 cherry-picks (#3035)

- enable bedrock runtime
- ignore flaky tests

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 40aaa1e4
Loading
Loading
Loading
Loading
+1166 −0

File added.

Preview size limit exceeded, changes collapsed.

+2 −1
Original line number Diff line number Diff line
@@ -8,4 +8,5 @@ aws.services=
# List of services to generate Event Stream operations for:
aws.services.eventstream.allowlist=\
    aws-sdk-transcribestreaming,\
    aws-sdk-s3
    aws-sdk-s3,\
    aws-sdk-bedrockruntime
+60 −17
Original line number Diff line number Diff line
@@ -78,6 +78,7 @@ class HttpBoundProtocolPayloadGenerator(
        val (shape, payloadMemberName) = when (httpMessageType) {
            HttpMessageType.RESPONSE -> operationShape.outputShape(model) to
                httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

            HttpMessageType.REQUEST -> operationShape.inputShape(model) to
                httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        }
@@ -97,6 +98,7 @@ class HttpBoundProtocolPayloadGenerator(
                is DocumentShape, is StructureShape, is UnionShape -> ProtocolPayloadGenerator.PayloadMetadata(
                    takesOwnership = false,
                )

                is StringShape, is BlobShape -> ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = true)
                else -> UNREACHABLE("Unexpected payload target type: $type")
            }
@@ -110,8 +112,19 @@ class HttpBoundProtocolPayloadGenerator(
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        when (httpMessageType) {
            HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape, additionalPayloadContext)
            HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape, additionalPayloadContext)
            HttpMessageType.RESPONSE -> generateResponsePayload(
                writer,
                shapeName,
                operationShape,
                additionalPayloadContext,
            )

            HttpMessageType.REQUEST -> generateRequestPayload(
                writer,
                shapeName,
                operationShape,
                additionalPayloadContext,
            )
        }
    }

@@ -119,13 +132,20 @@ class HttpBoundProtocolPayloadGenerator(
        writer: RustWriter, shapeName: String, operationShape: OperationShape,
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        val payloadMemberName =
            httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

        if (payloadMemberName == null) {
            val serializerGenerator = protocol.structuredDataSerializer()
            generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape))
        } else {
            generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext)
            generatePayloadMemberSerializer(
                writer,
                shapeName,
                operationShape,
                payloadMemberName,
                additionalPayloadContext,
            )
        }
    }

@@ -133,13 +153,24 @@ class HttpBoundProtocolPayloadGenerator(
        writer: RustWriter, shapeName: String, operationShape: OperationShape,
        additionalPayloadContext: AdditionalPayloadContext,
    ) {
        val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName
        val payloadMemberName =
            httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName

        if (payloadMemberName == null) {
            val serializerGenerator = protocol.structuredDataSerializer()
            generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape))
            generateStructureSerializer(
                writer,
                shapeName,
                serializerGenerator.operationOutputSerializer(operationShape),
            )
        } else {
            generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext)
            generatePayloadMemberSerializer(
                writer,
                shapeName,
                operationShape,
                payloadMemberName,
                additionalPayloadContext,
            )
        }
    }

@@ -152,16 +183,24 @@ class HttpBoundProtocolPayloadGenerator(
    ) {
        val serializerGenerator = protocol.structuredDataSerializer()

        if (operationShape.isEventStream(model)) {
        if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
            val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
                writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName, additionalPayloadContext)
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                shapeName,
                additionalPayloadContext,
            )
        } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
            val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
                writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output", additionalPayloadContext)
            } else {
                throw CodegenException("Payload serializer for event streams with an invalid configuration")
            }
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                "output",
                additionalPayloadContext,
            )
        } else {
            val bodyMetadata = payloadMetadata(operationShape)
            val payloadMember = when (httpMessageType) {
@@ -268,6 +307,7 @@ class HttpBoundProtocolPayloadGenerator(
                                Vec::new()
                                """,
                            )

                            is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape))
                            is UnionShape -> rust("#T()", serializerGenerator.unsetUnion(targetShape))
                            else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported")
@@ -310,6 +350,7 @@ class HttpBoundProtocolPayloadGenerator(
                    rust("$payloadName.into_inner()")
                }
            }

            is StructureShape, is UnionShape -> {
                check(
                    !((targetShape as? UnionShape)?.isEventStream() ?: false),
@@ -320,12 +361,14 @@ class HttpBoundProtocolPayloadGenerator(
                    serializer.payloadSerializer(member),
                )
            }

            is DocumentShape -> {
                rust(
                    "#T($payloadName)",
                    serializer.documentSerializer(),
                )
            }

            else -> PANIC("Unexpected payload target type: $targetShape")
        }
    }
+2 −0
Original line number Diff line number Diff line
@@ -142,6 +142,7 @@ mod tests {
        use super::*;

        #[tokio::test]
        #[ignore]
        async fn test_rustls_connector_can_make_http_requests() {
            let conn = Adapter::builder().build(https());
            let conn = DynConnector::new(conn);
@@ -151,6 +152,7 @@ mod tests {
        }

        #[tokio::test]
        #[ignore]
        async fn test_rustls_connector_can_make_https_requests() {
            let conn = Adapter::builder().build(https());
            let conn = DynConnector::new(conn);