Unverified Commit 25def3a6 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

Use the same protocol dispatcher logic used in the client codegen (#827)



The server protocol dipatch and code geneneration is implemented by
having the HTTP logic inside ServerHttpProtocolGenerator and using
specification we can apply specific ser/de and code.
Protocols (like ServerRestJson) can inherit from ServerHttpProtocolGenerator
and tune their implementations.

Other class that have been added:
* The ProtocolLoader is implemented inside ServerProtocolLoader
  (ServerProtocol.kt)
  It matches the right protocol and casts it to
  a ServerProtocolGenerator that can be used inside the
  ProtocolGeneratorFactory.
* The ProtocolGenerator is implemented inside ServerProtocolGenerator (
  ServerProtocolGenerator.kt)
  It implements serverProtocolOperation() that can build the
  operation definition as we like. This method is than called by
  ServiceGenerator per every operation in the model.
* RestJson and RestJsonFactory are implemented in RestJson.kt and
  inherit from ServerHttpProtocolGenerator
* Implement the HTTP protocol definition and adapt tests

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>

Co-authored-by: default avatardavid-perez <d@vidp.dev>
Co-authored-by: default avatarRussell Cohen <rcoh@amazon.com>
parent 643f2eee
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -28,9 +28,10 @@ val CodegenTests = listOf(
)

/**
 * The fluent client is generated to prevent warnings in RustDoc since the client is
 * referenced by multiple documentations.
 * TODO: review client generation in the future.
 * `includeFluentClient` must be set to `false` as we are not generating all the supporting
 * code for it.
 * TODO: Review how can we make this a default in the server so that customers don't
 *       have to specify it.
 */
fun generateSmithyBuild(tests: List<CodegenTest>): String {
    val projections =
@@ -40,7 +41,7 @@ fun generateSmithyBuild(tests: List<CodegenTest>): String {
                "plugins": {
                    "rust-server-codegen": {
                      "codegen": {
                        "includeFluentClient": true
                        "includeFluentClient": false
                      },
                      "runtimeConfig": {
                        "relativePath": "${rootProject.projectDir.absolutePath}/rust-runtime"
+6 −6
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ import java.util.logging.Logger
 *  `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which
 *  enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models.
 */
class RustCodegenPlugin : SmithyBuildPlugin {
class RustCodegenServerPlugin : SmithyBuildPlugin {
    private val logger = Logger.getLogger(javaClass.name)

    override fun getName(): String = "rust-server-codegen"
@@ -35,22 +35,22 @@ class RustCodegenPlugin : SmithyBuildPlugin {
    override fun execute(context: PluginContext) {
        // Suppress extremely noisy logs about reserved words
        Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF
        // Discover `RustCodegenDecorators` on the classpath. `RustCodegenDectorator` return different types of
        // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of
        // customization. A customization is a function of:
        // - location (e.g. the mutate section of an operation)
        // - context (e.g. the of the operation)
        // - writer: The active RustWriter at the given location
        val codegenDecorator = CombinedCodegenDecorator.fromClasspath(context)

        // CodegenVisitor is the main driver of code generation that traverses the model and generates code
        CodegenVisitor(context, codegenDecorator).execute()
        // ServerCodegenVisitor is the main driver of code generation that traverses the model and generates code
        ServerCodegenVisitor(context, codegenDecorator).execute()
    }

    companion object {
        /** SymbolProvider
         * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider
         *
         * The Symbol provider is composed of a base `SymbolVisitor` which handles the core functionality, then is layered
         * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered
         * with other symbol providers, documented inline, to handle the full scope of Smithy types.
         */
        fun baseSymbolProvider(
@@ -63,7 +63,7 @@ class RustCodegenPlugin : SmithyBuildPlugin {
                .let {
                    EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model)
                }
                // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it) }
+16 −72
Original line number Diff line number Diff line
@@ -5,11 +5,9 @@

package software.amazon.smithy.rust.codegen.server.smithy

import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeVisitor
@@ -22,10 +20,8 @@ import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpDeserializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.RestJson1HttpSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.DefaultPublicModules
import software.amazon.smithy.rust.codegen.smithy.RustCrate
@@ -40,12 +36,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
@@ -53,7 +44,6 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBox
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger

@@ -61,7 +51,7 @@ import java.util.logging.Logger
 * Entrypoint for server-side code generation. This class will walk the in-memory model and
 * generate all the needed types by calling the accept() function on the available shapes.
 */
class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustCodegenDecorator) :
class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator: RustCodegenDecorator) :
    ShapeVisitor.Default<Unit>() {

    private val logger = Logger.getLogger(javaClass.name)
@@ -75,10 +65,6 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
    private val protocolGeneratorFactory: ProtocolGeneratorFactory<ProtocolGenerator>
    private val protocolGenerator: ProtocolGenerator

    private val httpSerializerGenerator: RestJson1HttpSerializerGenerator
    private val httpDeserializerGenerator: RestJson1HttpDeserializerGenerator
    private val httpBindingResolver: HttpBindingResolver

    init {
        val symbolVisitorConfig =
            SymbolVisitorConfig(
@@ -88,16 +74,16 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
        val baseModel = baselineTransform(context.model)
        val service = settings.getService(baseModel)
        val (protocol, generator) =
            ProtocolLoader(
            ServerProtocolLoader(
                codegenDecorator.protocols(
                    service.id,
                    ProtocolLoader.DefaultProtocols
                    ServerProtocolLoader.DefaultProtocols
                )
            )
                .protocolFor(context.model, service)
        protocolGeneratorFactory = generator
        model = generator.transformModel(codegenDecorator.transformModel(service, baseModel))
        val baseProvider = RustCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig)
        val baseProvider = RustCodegenServerPlugin.baseSymbolProvider(model, service, symbolVisitorConfig)
        symbolProvider =
            codegenDecorator.symbolProvider(generator.symbolProvider(model, baseProvider))

@@ -105,24 +91,6 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

        rustCrate = RustCrate(context.fileManifest, symbolProvider, DefaultPublicModules)
        protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)

        httpBindingResolver =
            HttpTraitHttpBindingResolver(
                codegenContext.model,
                ProtocolContentTypes.consistent("application/json"),
            )

        when (codegenContext.protocol) {
            RestJson1Trait.ID -> {
                httpSerializerGenerator =
                    RestJson1HttpSerializerGenerator(codegenContext, httpBindingResolver)
                httpDeserializerGenerator =
                    RestJson1HttpDeserializerGenerator(codegenContext, httpBindingResolver)
            }
            else -> {
                TODO("Protocol ${codegenContext.protocol} not supported yet")
            }
        }
    }

    /**
@@ -147,7 +115,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
    /**
     * Execute code generation
     *
     * 1. Load the service from `RustSettings`.
     * 1. Load the service from [RustSettings].
     * 2. Traverse every shape in the closure of the service.
     * 3. Loop through each shape and visit them (calling the override functions in this class)
     * 4. Call finalization tasks specified by decorators.
@@ -194,28 +162,11 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC

    override fun getDefault(shape: Shape?) {}

    /**
     * Operation Shape Visitor
     *
     * For each operation shape, generate the corresponding protocol implementation.
     */
    override fun operationShape(shape: OperationShape?) {
        logger.info("[rust-server-codegen] Generating operation $shape")
        if (shape != null) {
            rustCrate.useShapeWriter(shape) { writer ->
                shape.let {
                    httpDeserializerGenerator.render(writer, it)
                    httpSerializerGenerator.render(writer, it)
                }
            }
        }
    }

    /**
     * Structure Shape Visitor
     *
     * For each structure shape, generate:
     * - A Rust structure for the shape (`StructureGenerator`).
     * - A Rust structure for the shape ([StructureGenerator]).
     * - A builder for the shape.
     *
     * This function _does not_ generate any serializers.
@@ -224,7 +175,6 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
        logger.info("[rust-server-codegen] Generating a structure $shape")
        rustCrate.useShapeWriter(shape) { writer ->
            StructureGenerator(model, symbolProvider, writer, shape).render()
            if (!shape.hasTrait<SyntheticInputTrait>()) {
            val builderGenerator =
                BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
            builderGenerator.render(writer)
@@ -233,12 +183,11 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
            }
        }
    }
    }

    /**
     * String Shape Visitor
     *
     * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes.
     * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes.
     */
    override fun stringShape(shape: StringShape) {
        logger.info("[rust-server-codegen] Generating an enum $shape")
@@ -274,15 +223,10 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
     */
    override fun serviceShape(shape: ServiceShape) {
        logger.info("[rust-server-codegen] Generating a service $shape")
        ServiceGenerator(
        ServerServiceGenerator(
            rustCrate,
            protocolGenerator,
            ProtocolSupport(
                requestDeserialization = true,
                requestBodyDeserialization = true,
                responseSerialization = true,
                errorSerialization = true
            ),
            protocolGeneratorFactory.support(),
            codegenContext,
            codegenDecorator
        )
+21 −27
Original line number Diff line number Diff line
@@ -7,18 +7,22 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfigGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport

class ServiceGenerator(
/**
 * ServerServiceGenerator
 *
 * Service generator is the main codegeneration entry point for Smithy services. Individual structures and unions are
 * generated in codegen visitor, but this class handles all protocol-specific code generation (i.e. operations).
 */
class ServerServiceGenerator(
    private val rustCrate: RustCrate,
    private val protocolGenerator: ProtocolGenerator,
    private val protocolSupport: ProtocolSupport,
@@ -27,23 +31,23 @@ class ServiceGenerator(
) {
    private val index = TopDownIndex.of(context.model)

    /**
     * Render Service Specific code. Code will end up in different files via [useShapeWriter]. See `SymbolVisitor.kt`
     * which assigns a symbol location to each shape.
     *
     */
    fun render() {
        val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id }
        operations.map { operation ->
            rustCrate.useShapeWriter(operation) { operationWriter ->
                rustCrate.useShapeWriter(operation.inputShape(context.model)) { inputWriter ->
                    protocolGenerator.renderOperation(
                protocolGenerator.serverRenderOperation(
                    operationWriter,
                        inputWriter,
                    operation,
                    decorator.operationCustomizations(context, operation, listOf())
                )
                    // render protocol tests into `operation.rs` (note operationWriter vs.
                    // inputWriter)
                    ProtocolTestGenerator(context, protocolSupport, operation, operationWriter)
                ServerProtocolTestGenerator(context, protocolSupport, operation, operationWriter)
                    .render()
            }
            }
            rustCrate.withModule(RustModule.Error) { writer ->
                CombinedErrorGenerator(context.model, context.symbolProvider, operation)
                    .render(writer)
@@ -51,15 +55,5 @@ class ServiceGenerator(
        }

        TopLevelErrorGenerator(context, operations).render(rustCrate)

        rustCrate.withModule(RustModule.Config) { writer ->
            ServiceConfigGenerator.withBaseBehavior(
                context,
                extraCustomizations = decorator.configCustomizations(context, listOf())
            )
                .render(writer)
        }

        rustCrate.lib { it.write("pub use config::Config;") }
    }
}
+44 −53
Original line number Diff line number Diff line
@@ -12,7 +12,6 @@ import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.IdempotencyTokenTrait
import software.amazon.smithy.protocoltests.traits.AppliesTo
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
@@ -30,12 +29,13 @@ import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.server.smithy.protocols.HttpServerTraits
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.findMemberWithTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
@@ -45,17 +45,10 @@ import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger

data class ProtocolSupport(
    val requestDeserialization: Boolean,
    val requestBodyDeserialization: Boolean,
    val responseSerialization: Boolean,
    val errorSerialization: Boolean
)

/**
 * Generate protocol tests for an operation
 */
class ProtocolTestGenerator(
class ServerProtocolTestGenerator(
    private val codegenContext: CodegenContext,
    private val protocolSupport: ProtocolSupport,
    private val operationShape: OperationShape,
@@ -67,11 +60,19 @@ class ProtocolTestGenerator(
    private val outputShape = operationShape.outputShape(codegenContext.model)
    private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape)
    private val operationIndex = OperationIndex.of(codegenContext.model)
    private val operationMod = RuntimeType("operation", null, "crate")

    private val instantiator = with(codegenContext) {
        Instantiator(symbolProvider, model, runtimeConfig)
    }
    private val httpServerTraits = HttpServerTraits()

    private val codegenScope = arrayOf(
        "ParseStrictResponse" to RuntimeType.parseStrictResponse(codegenContext.runtimeConfig),
        "SerializeHttpResponse" to httpServerTraits.serializeHttpResponse(codegenContext.runtimeConfig),
        "ParseHttpRequest" to httpServerTraits.parseHttpRequest(codegenContext.runtimeConfig),
        "Bytes" to RuntimeType.Bytes,
        "SmithyHttp" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(),
        "Http" to CargoDependency.Http.asType(),
    )

    sealed class TestCase {
        abstract val testCase: HttpMessageTestCase
@@ -178,36 +179,22 @@ class ProtocolTestGenerator(
            rust("/* test case disabled for this protocol (not yet supported) */")
            return
        }
        val customToken = if (inputShape.findMemberWithTrait<IdempotencyTokenTrait>(codegenContext.model) != null) {
            """.make_token("00000000-0000-4000-8000-000000000000")"""
        } else ""
        rust(
            """let config = #T::Config::builder()$customToken.build();""",
            RuntimeType.Config
        )
        writeInline("let expected =")
        instantiator.render(this, inputShape, httpRequestTestCase.params)
        write(";")
        rust("""let op = expected.make_operation(&config).await.expect("operation failed to build");""")
        rust("let (http_request, parts) = op.into_request_response().0.into_parts();")
        rustTemplate("""let op = #{op}::new();""", "op" to operationSymbol)
        with(httpRequestTestCase) {
            host.orNull()?.also { host ->
                val withScheme = "http://$host"
                rust(
                rustTemplate(
                    """
                    let mut http_request = http_request;
                    let ep = #T::endpoint::Endpoint::mutable(#T::Uri::from_static(${withScheme.dq()}));
                    let ep = #{SmithyHttp}::endpoint::Endpoint::mutable(#{Http}::Uri::from_static(${withScheme.dq()}));
                    ep.set_endpoint(http_request.uri_mut(), parts.acquire().get());
                    """,
                    CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(), CargoDependency.Http.asType()
                    *codegenScope,
                )
            }
            rust(
                """
                assert_eq!(http_request.method(), ${method.dq()});
                assert_eq!(http_request.uri().path(), ${uri.dq()});
                """
            )
            resolvedHost.orNull()?.also { host ->
                rust("""assert_eq!(http_request.uri().host().expect("host should be set"), ${host.dq()});""")
            }
@@ -221,8 +208,7 @@ class ProtocolTestGenerator(
        if (protocolSupport.requestBodyDeserialization) {
            // "If no request body is defined, then no assertions are made about the body of the message."
            httpRequestTestCase.body.orNull()?.also { body ->
                val fnName = "deser_${operationShape.id.name.toSnakeCase()}_request"
                checkBody(this, fnName, body)
                checkBody(this, body, httpRequestTestCase.uri)
            }
        }

@@ -260,26 +246,19 @@ class ProtocolTestGenerator(
            rust("/* test case disabled for this protocol (not yet supported) */")
            return
        }
        val fnName = if (expectedShape.hasTrait<ErrorTrait>())
            "serialize_${operationShape.id.name.toSnakeCase()}_error"
        else
            "serialize_${operationShape.id.name.toSnakeCase()}_response"
        writeInline("let expected_output =")
        instantiator.render(this, expectedShape, testCase.params)
        write(";")
        write(
            """let http_response = #T(&expected_output).expect("failed to serialize response");""",
            operationMod.member(fnName)
        )
        rustTemplate(
            """
            use #{parse_http_response};
            let parser = #{op}::new();
            let parsed = parser.parse_loaded(&http_response);
            use #{ParseStrictResponse};
            use #{SerializeHttpResponse};
            let op = #{op}::new();
            let http_response = op.serialize(&expected_output).expect("unable to serialize response body");
            let parsed = op.parse(&http_response);
            """,
            *codegenScope,
            "op" to operationSymbol,
            "parse_http_response" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType()
                .member("response::ParseHttpResponse"),
        )
        if (expectedShape.hasTrait<ErrorTrait>()) {
            val errorSymbol = operationShape.errorSymbol(codegenContext.symbolProvider)
@@ -292,7 +271,7 @@ class ProtocolTestGenerator(
                rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
            }
        } else {
            rust("let parsed = parsed.unwrap();")
            rust("let parsed = parsed.expect(\"unable to parse response\");")
            outputShape.members().forEach { member ->
                val memberName = codegenContext.symbolProvider.toMemberName(member)
                if (member.isStreaming(codegenContext.model)) {
@@ -331,14 +310,26 @@ class ProtocolTestGenerator(
        basicCheck(forbidHeaders, rustWriter, "forbidden_headers", "forbid_headers")
    }

    private fun checkBody(rustWriter: RustWriter, fnName: String, body: String) {
        rustWriter.write(
            """let http_request = http_request.map(|body| #T::from(body.bytes().unwrap().to_vec()));""",
            RuntimeType.Bytes
    private fun checkBody(rustWriter: RustWriter, body: String, uri: String) {
        rustWriter.rustTemplate(
            """
            let http_request = http::Request::builder()
                .uri(${uri.dq()})
                .body(#{Bytes}::from_static(b${body.dq()}))
                .unwrap();
            """,
            "body" to body,
            "uri" to uri,
            *codegenScope,
        )
        rustWriter.write(
            """let body = #T(&http_request).expect("failed to parse request");""",
            operationMod.member(fnName)
        rustWriter.rustTemplate(
            """
            use #{ParseHttpRequest};
            let op = #{op}::new();
            let body = op.parse_loaded(&http_request).expect("failed to parse request");
            """,
            "op" to operationSymbol,
            *codegenScope,
        )
        if (body == "") {
            rustWriter.write("// No body")
Loading