Unverified Commit eac52eb6 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Fix event stream `:content-type` for struct messages (#3603)

Event stream operations with struct shaped messages were using the wrong
`:content-type` message header value, which I think wasn't caught before
since the supported AWS S3/Transcribe event stream operations don't
serialize struct messages. This PR fixes the message content type
serialization.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 5461e4fc
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -58,3 +58,9 @@ message = "SDK crates now set the `rust-version` property in their Cargo.toml fi
references = ["smithy-rs#3601"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Fix event stream `:content-type` message headers for struct messages. Note: this was the `:content-type` header on individual event message frames that was incorrect, not the HTTP `content-type` header for the initial request."
references = ["smithy-rs#3603"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" }
author = "jdisanti"
+3 −0
Original line number Diff line number Diff line
@@ -83,6 +83,9 @@ class AwsJsonHttpBindingResolver(
        "application/x-amz-json-${awsJsonVersion.value}"

    override fun responseContentType(operationShape: OperationShape): String = requestContentType(operationShape)

    override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
        ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/json")
}

/**
+4 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.core.smithy.protocols

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
@@ -38,6 +39,9 @@ class AwsQueryCompatibleHttpBindingResolver(

    override fun responseContentType(operationShape: OperationShape): String =
        awsJsonHttpBindingResolver.requestContentType(operationShape)

    override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
        awsJsonHttpBindingResolver.eventStreamMessageContentType(memberShape)
}

class AwsQueryCompatible(
+35 −3
Original line number Diff line number Diff line
@@ -8,8 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
@@ -98,6 +100,11 @@ interface HttpBindingResolver {
     * Determines the response content type for given [operationShape].
     */
    fun responseContentType(operationShape: OperationShape): String?

    /**
     * Determines the value of the event stream `:content-type` header based on union member
     */
    fun eventStreamMessageContentType(memberShape: MemberShape): String?
}

/**
@@ -108,12 +115,30 @@ data class ProtocolContentTypes(
    val requestDocument: String? = null,
    /** Response content type override for when the shape is a Document */
    val responseDocument: String? = null,
    /** EventStream content type */
    /** EventStream content type initial request/response content-type */
    val eventStreamContentType: String? = null,
    /** EventStream content type for struct message shapes (for `:content-type`) */
    val eventStreamMessageContentType: String? = null,
) {
    companion object {
        /** Create an instance of [ProtocolContentTypes] where all content types are the same */
        fun consistent(type: String) = ProtocolContentTypes(type, type, type)
        fun consistent(type: String) = ProtocolContentTypes(type, type, type, type)

        /**
         * Returns the event stream message `:content-type` for the given event stream union member shape.
         *
         * The `protocolContentType` is the content-type to use for non-string/non-blob shapes.
         */
        fun eventStreamMemberContentType(
            model: Model,
            memberShape: MemberShape,
            protocolContentType: String?,
        ): String? =
            when (model.expectShape(memberShape.target)) {
                is StringShape -> "text/plain"
                is BlobShape -> "application/octet-stream"
                else -> protocolContentType
            }
    }
}

@@ -121,7 +146,7 @@ data class ProtocolContentTypes(
 * An [HttpBindingResolver] that relies on the HttpTrait data in the Smithy models.
 */
open class HttpTraitHttpBindingResolver(
    model: Model,
    private val model: Model,
    private val contentTypes: ProtocolContentTypes,
) : HttpBindingResolver {
    private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model)
@@ -158,6 +183,9 @@ open class HttpTraitHttpBindingResolver(
            contentTypes.eventStreamContentType,
        ).orNull()

    override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
        ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, contentTypes.eventStreamMessageContentType)

    // Sort the members after extracting them from the map to have a consistent order
    private fun mappedBindings(bindings: Map<String, HttpBinding>): List<HttpBindingDescriptor> =
        bindings.values.map(::HttpBindingDescriptor).sortedBy { it.memberName }
@@ -172,6 +200,7 @@ open class StaticHttpBindingResolver(
    private val httpTrait: HttpTrait,
    private val requestContentType: String,
    private val responseContentType: String,
    private val eventStreamMessageContentType: String? = null,
) : HttpBindingResolver {
    private fun bindings(shape: ToShapeId?) =
        shape?.let { model.expectShape(it.toShapeId()) }?.members()
@@ -192,4 +221,7 @@ open class StaticHttpBindingResolver(
    override fun requestContentType(operationShape: OperationShape): String = requestContentType

    override fun responseContentType(operationShape: OperationShape): String = responseContentType

    override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
        ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, eventStreamMessageContentType)
}
+6 −10
Original line number Diff line number Diff line
@@ -197,7 +197,6 @@ class HttpBoundProtocolPayloadGenerator(
        if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
            val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                shapeName,
@@ -206,7 +205,6 @@ class HttpBoundProtocolPayloadGenerator(
        } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
            val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
            writer.serializeViaEventStream(
                operationShape,
                payloadMember,
                serializerGenerator,
                "output",
@@ -239,7 +237,6 @@ class HttpBoundProtocolPayloadGenerator(
    }

    private fun RustWriter.serializeViaEventStream(
        operationShape: OperationShape,
        memberShape: MemberShape,
        serializerGenerator: StructuredDataSerializerGenerator,
        outerName: String,
@@ -248,11 +245,10 @@ class HttpBoundProtocolPayloadGenerator(
        val memberName = symbolProvider.toMemberName(memberShape)
        val unionShape = model.expectShape(memberShape.target, UnionShape::class.java)

        val contentType =
            when (target) {
                CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape)
                CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape)
            }
        val payloadContentType =
            httpBindingResolver.eventStreamMessageContentType(memberShape)
                ?: throw CodegenException("event streams must set a content type")

        val errorMarshallerConstructorFn =
            EventStreamErrorMarshallerGenerator(
                model,
@@ -261,7 +257,7 @@ class HttpBoundProtocolPayloadGenerator(
                symbolProvider,
                unionShape,
                serializerGenerator,
                contentType ?: throw CodegenException("event streams must set a content type"),
                payloadContentType,
            ).render()
        val marshallerConstructorFn =
            EventStreamMarshallerGenerator(
@@ -271,7 +267,7 @@ class HttpBoundProtocolPayloadGenerator(
                symbolProvider,
                unionShape,
                serializerGenerator,
                contentType,
                payloadContentType,
            ).render()

        // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
Loading