Unverified Commit 41d948c5 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Serialize Request bodies for restJson (#255)

* Serialize Request Bodies for RestJson

* Move the bodies to the serializer module

* RestJson body CR feedback and bug fixes
parent 4f958f37
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -36,6 +36,10 @@ val CodegenTests = listOf(
        "aws.protocoltests.restjson#RestJson",
        "rest_json"
    ),
    CodegenTest(
        "aws.protocoltests.restjson#RestJsonExtras",
        "rest_json_extas"
    ),
    CodegenTest(
        "crate#Config",
        "naming_test", """
+59 −0
Original line number Diff line number Diff line
$version: "1.0"

namespace aws.protocoltests.restjson

use aws.protocols#restJson1
use aws.api#service
use smithy.test#httpRequestTests


/// A REST JSON service that sends JSON requests and responses.
@service(sdkId: "Rest Json Protocol")
@restJson1
service RestJsonExtras {
    version: "2019-12-16",
    operations: [EnumPayload, StringPayload]
}

@http(uri: "/EnumPayload", method: "POST")
@httpRequestTests([
    {
        id: "EnumPayload",
        uri: "/EnumPayload",
        body: "enumvalue",
        params: { payload: "enumvalue" },
        method: "POST",
        protocol: "aws.protocols#restJson1"
    }
])
operation EnumPayload {
    input: EnumPayloadInput
}

structure EnumPayloadInput {
    @httpPayload
    payload: StringEnum
}

@enum([{"value": "enumvalue", "name": "V"}])
string StringEnum

@http(uri: "/StringPayload", method: "POST")
@httpRequestTests([
    {
        id: "StringPayload",
        uri: "/StringPayload",
        body: "rawstring",
        params: { payload: "rawstring" },
        method: "POST",
        protocol: "aws.protocols#restJson1"
    }
])
operation StringPayload {
    input: StringPayloadInput
}

structure StringPayloadInput {
    @httpPayload
    payload: String
}
+4 −0
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@ import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@@ -231,6 +233,7 @@ class SymbolVisitor(
        val isError = shape.hasTrait(ErrorTrait::class.java)
        val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
        val isOutput = shape.hasTrait(SyntheticOutputTrait::class.java)
        val isBody = shape.hasTrait(InputBodyTrait::class.java) || shape.hasTrait(OutputBodyTrait::class.java)
        val name = StringUtils.capitalize(shape.id.name).letIf(isError && config.codegenConfig.renameExceptions) {
            // TODO: Do we want to do this?
            // https://github.com/awslabs/smithy-rs/issues/77
@@ -241,6 +244,7 @@ class SymbolVisitor(
            isError -> builder.locatedIn(Errors)
            isInput -> builder.locatedIn(Inputs)
            isOutput -> builder.locatedIn(Outputs)
            isBody -> builder.locatedIn(Serializers)
            else -> builder.locatedIn(Models)
        }.build()
    }
+31 −7
Original line number Diff line number Diff line
@@ -49,13 +49,24 @@ abstract class HttpProtocolGenerator(
) {
    private val symbolProvider = protocolConfig.symbolProvider
    private val model = protocolConfig.model
    fun renderOperation(operationWriter: RustWriter, inputWriter: RustWriter, operationShape: OperationShape, customizations: List<OperationCustomization>) {
    fun renderOperation(
        operationWriter: RustWriter,
        inputWriter: RustWriter,
        operationShape: OperationShape,
        customizations: List<OperationCustomization>
    ) {
        /* if (operationShape.hasTrait(EndpointTrait::class.java)) {
            TODO("https://github.com/awslabs/smithy-rs/issues/197")
        } */
        val inputShape = operationShape.inputShape(model)
        val inputSymbol = symbolProvider.toSymbol(inputShape)
        val builderGenerator = OperationInputBuilderGenerator(model, symbolProvider, operationShape, protocolConfig.moduleName, customizations)
        val builderGenerator = OperationInputBuilderGenerator(
            model,
            symbolProvider,
            operationShape,
            protocolConfig.moduleName,
            customizations
        )
        builderGenerator.render(inputWriter)
        // impl OperationInputShape { ... }

@@ -63,7 +74,7 @@ abstract class HttpProtocolGenerator(
            toHttpRequestImpl(this, operationShape, inputShape)
            val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body
            val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) }
            toBodyImpl(this, inputShape, body)
            toBodyImpl(this, inputShape, body, operationShape)
            // TODO: streaming shapes need special support
            rustBlock(
                "pub fn assemble(builder: #1T, body: #3T) -> #2T<#3T>",
@@ -130,14 +141,18 @@ abstract class HttpProtocolGenerator(
        }
    }

    protected fun fromResponseFun(implBlockWriter: RustWriter, operationShape: OperationShape, f: RustWriter.() -> Unit) {
    protected fun fromResponseFun(
        implBlockWriter: RustWriter,
        operationShape: OperationShape,
        block: RustWriter.() -> Unit
    ) {
        implBlockWriter.rustBlock(
            "fn from_response(response: &#T<impl AsRef<[u8]>>) -> Result<#T, #T>",
            RuntimeType.Http("response::Response"),
            symbolProvider.toSymbol(operationShape.outputShape(model)),
            operationShape.errorSymbol(symbolProvider)
        ) {
            f(this)
            block(this)
        }
    }

@@ -150,12 +165,21 @@ abstract class HttpProtocolGenerator(
     *
     * Your implementation MUST call [bodyBuilderFun] to create the public method.
     */
    abstract fun toBodyImpl(implBlockWriter: RustWriter, inputShape: StructureShape, inputBody: StructureShape?)
    abstract fun toBodyImpl(
        implBlockWriter: RustWriter,
        inputShape: StructureShape,
        inputBody: StructureShape?,
        operationShape: OperationShape
    )

    /**
     * Add necessary methods to the impl block for the input shape.
     *
     * Your implementation MUST call [httpBuilderFun] to create the public method.
     */
    abstract fun toHttpRequestImpl(implBlockWriter: RustWriter, operationShape: OperationShape, inputShape: StructureShape)
    abstract fun toHttpRequestImpl(
        implBlockWriter: RustWriter,
        operationShape: OperationShape,
        inputShape: StructureShape
    )
}
+5 −3
Original line number Diff line number Diff line
@@ -172,7 +172,7 @@ class HttpProtocolTestGenerator(
        checkRequiredHeaders(this, httpRequestTestCase.requireHeaders)
        if (protocolSupport.requestBodySerialization) {
            // "If no request body is defined, then no assertions are made about the body of the message."
            httpRequestTestCase.body.orNull()?.let { body ->
            httpRequestTestCase.body.orNull()?.also { body ->
                checkBody(this, body, httpRequestTestCase.bodyMediaType.orNull())
            }
        }
@@ -253,7 +253,7 @@ class HttpProtocolTestGenerator(
        rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""")
        if (body == "") {
            rustWriter.write("// No body")
            rustWriter.write("assert!(&body.is_empty());")
            rustWriter.write("assert_eq!(std::str::from_utf8(body).unwrap(), ${"".dq()});")
        } else {
            // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
            assertOk(rustWriter) {
@@ -383,7 +383,9 @@ class HttpProtocolTestGenerator(
        // or because they are flaky
        private val DisableTests = setOf(
            // This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
            "AwsJson11Enums"
            "AwsJson11Enums",
            "RestJsonJsonEnums",
            "RestJsonLists"
        )
    }
}
Loading