Unverified Commit 138320e9 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add restJson error parsing & prefix headers support (#274)

* restJson1 Header Deserialization Support

This adds support for the `@httpHeader` trait during request deserialization. This is facilitated via a `Read` abstraction which allows parsers to consume part of an input and return the remaining input. We use this to enable things like `HttpDates` which actually include commas to be parsed in a comma delimited fashion.

The internal libraries may eventually be refactored to reduce the amount of generics if it proves to be a compiler bottleneck.

* Add support for parsing request bodies

* Add response code parsing

* Lots of refactorings to restJson response parsing

* Fix some clippy lints

* Add document deserialization support and delete the failing tests

* Add new error handling (all RestJson protocol tests passing!)

* Add prefix headers parsing support

* Fix Rust Runtime clippy warnings

* Cleanup some more dead code / clippy fixes

* Update external functions to take HeaderMap instead of the value iterator

* Hard error on unreconized binding location

* Refactor parse function generation to be on-demand

* One more round of refactoring

* Few more cleanups
parent 89d984df
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -168,6 +168,8 @@ typealias Writable = RustWriter.() -> Unit
 */
fun writable(w: Writable): Writable = w

fun writable(w: String): Writable = writable { rust(w) }

class RustWriter private constructor(
    private val filename: String,
    val namespace: String,
@@ -344,7 +346,7 @@ class RustWriter private constructor(
                    addDepsRecursively(t)
                    t.rustType().render(fullyQualified = true)
                }
                else -> throw CodegenException("Invalid type provided to RustSymbolFormatter")
                else -> throw CodegenException("Invalid type provided to RustSymbolFormatter: $t")
            }
        }
    }
+1 −1
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
            )
        )
        try {
            "cargo fmt".runCommand(fileManifest.baseDir)
            "cargo fmt".runCommand(fileManifest.baseDir, timeout = 5)
        } catch (_: CommandFailed) {
            logger.warning("Generated output did not parse [${service.id}]")
        }
+0 −3
Original line number Diff line number Diff line
@@ -364,9 +364,6 @@ class HttpProtocolTestGenerator(
            FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),

            FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Response),
        )
        private val RunOnly: Set<String>? = null

+57 −26
Original line number Diff line number Diff line
@@ -28,16 +28,19 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait
import software.amazon.smithy.model.traits.IdempotencyTokenTrait
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.util.dq
@@ -55,19 +58,23 @@ class Instantiator(
    private val runtimeConfig: RuntimeConfig
) {

    fun render(writer: RustWriter, shape: Shape, arg: Node) {
    // The Rust HTTP library lower cases headers but Smithy protocol tests
    // contain httpPrefix headers with uppercase keys
    data class Ctx(val lowercaseMapKeys: Boolean)

    fun render(writer: RustWriter, shape: Shape, arg: Node, ctx: Ctx = Ctx(false)) {
        when (shape) {
            // Compound Shapes
            is StructureShape -> renderStructure(writer, shape, arg as ObjectNode)
            is UnionShape -> renderUnion(writer, shape, arg as ObjectNode)
            is StructureShape -> renderStructure(writer, shape, arg as ObjectNode, ctx)
            is UnionShape -> renderUnion(writer, shape, arg as ObjectNode, ctx)

            // Collections
            is ListShape -> renderList(writer, shape, arg as ArrayNode)
            is MapShape -> renderMap(writer, shape, arg as ObjectNode)
            is SetShape -> renderSet(writer, shape, arg as ArrayNode)
            is ListShape -> renderList(writer, shape, arg as ArrayNode, ctx)
            is MapShape -> renderMap(writer, shape, arg as ObjectNode, ctx)
            is SetShape -> renderSet(writer, shape, arg as ArrayNode, ctx)

            // Members, supporting potentially optional members
            is MemberShape -> renderMember(writer, shape, arg)
            is MemberShape -> renderMember(writer, shape, arg, ctx)

            // Wrapped Shapes
            is TimestampShape -> writer.write(
@@ -109,7 +116,8 @@ class Instantiator(
    private fun renderMember(
        writer: RustWriter,
        shape: MemberShape,
        arg: Node
        arg: Node,
        ctx: Ctx
    ) {
        val target = model.expectShape(shape.target)
        val symbol = symbolProvider.toSymbol(shape)
@@ -120,19 +128,32 @@ class Instantiator(
            writer.write("None")
        } else {
            writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) {
                render(this, target, arg)
                writer.conditionalBlock(
                    "Box::new(",
                    ")",
                    conditional = symbol.rustType().stripOuter<RustType.Option>() is RustType.Box
                ) {
                    render(
                        this,
                        target,
                        arg,
                        ctx.letIf(shape.getMemberTrait(model, HttpPrefixHeadersTrait::class.java).isPresent) {
                            ctx.copy(lowercaseMapKeys = true)
                        }
                    )
                }
            }
        }
    }

    private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode) {
    private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) {
        if (symbolProvider.toSymbol(shape).rustType() is RustType.HashSet) {
            if (!data.isEmpty) {
                writer.rustBlock("") {
                    write("let mut ret = #T::new();", RustType.HashSet.RuntimeType)
                    data.forEach { v ->
                        withBlock("ret.insert(", ");") {
                            renderMember(this, shape.member, v)
                            renderMember(this, shape.member, v, ctx)
                        }
                    }
                    write("ret")
@@ -141,7 +162,7 @@ class Instantiator(
                writer.write("#T::new()", RustType.HashSet.RuntimeType)
            }
        } else {
            renderList(writer, shape, data)
            renderList(writer, shape, data, ctx)
        }
    }

@@ -157,14 +178,19 @@ class Instantiator(
    private fun renderMap(
        writer: RustWriter,
        shape: MapShape,
        data: ObjectNode
        data: ObjectNode,
        ctx: Ctx,
    ) {
        val lowercase = when (ctx.lowercaseMapKeys) {
            true -> ".to_ascii_lowercase()"
            else -> ""
        }
        if (data.members.isNotEmpty()) {
            writer.rustBlock("") {
                write("let mut ret = #T::new();", RustType.HashMap.RuntimeType)
                data.members.forEach { (k, v) ->
                    withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") {
                        renderMember(this, shape.value, v)
                    withBlock("ret.insert(${k.value.dq()}.to_string()$lowercase,", ");") {
                        renderMember(this, shape.value, v, ctx)
                    }
                }
                write("ret")
@@ -182,7 +208,8 @@ class Instantiator(
    private fun renderUnion(
        writer: RustWriter,
        shape: UnionShape,
        data: ObjectNode
        data: ObjectNode,
        ctx: Ctx
    ) {
        val unionSymbol = symbolProvider.toSymbol(shape)
        check(data.members.size == 1)
@@ -194,7 +221,7 @@ class Instantiator(
        writer.write("#T::${memberName.toPascalCase()}", unionSymbol)
        // unions should specify exactly one member
        writer.withBlock("(", ")") {
            render(this, member, variant.value)
            render(this, member, variant.value, ctx)
        }
    }

@@ -206,11 +233,12 @@ class Instantiator(
    private fun renderList(
        writer: RustWriter,
        shape: CollectionShape,
        data: ArrayNode
        data: ArrayNode,
        ctx: Ctx
    ) {
        writer.withBlock("vec![", "]") {
            data.elements.forEach { v ->
                renderMember(this, shape.member, v)
                renderMember(this, shape.member, v, ctx)
                write(",")
            }
        }
@@ -239,7 +267,8 @@ class Instantiator(
    private fun renderStructure(
        writer: RustWriter,
        shape: StructureShape,
        data: ObjectNode
        data: ObjectNode,
        ctx: Ctx
    ) {
        writer.rustBlock("") {
            val isSyntheticInput = shape.hasTrait(SyntheticInputTrait::class.java)
@@ -259,12 +288,14 @@ class Instantiator(
            }
            writer.write("#T::builder()", symbolProvider.toSymbol(shape))
            data.members.forEach { (key, value) ->
                val (memberShape, targetShape) = getMember(shape, key)
                val func = symbolProvider.toMemberName(memberShape)
                if (!value.isNullNode) {
                    writer.withBlock(".$func(", ")") {
                        render(this, targetShape, value)
                    }
                val memberShape = shape.expectMember(key.value)
                writer.withBlock(".${memberShape.setterName()}(", ")") {
                    renderMember(
                        this,
                        memberShape,
                        value,
                        ctx
                    )
                }
            }
            if (isSyntheticInput) {
+70 −22
Original line number Diff line number Diff line
@@ -10,7 +10,10 @@ import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
@@ -54,18 +57,60 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
     * }
     * ```
     */
    fun generateDeserializeHeaderFn(binding: HttpBinding, writer: RustWriter): String {
    fun generateDeserializeHeaderFn(binding: HttpBinding): RuntimeType {
        check(binding.location == HttpBinding.Location.HEADER)
        val outputT = symbolProvider.toSymbol(binding.member)
        val fnName = "deser_header_${operationShape.id.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}"
        return RuntimeType.forInlineFun(fnName, "http_serde") { writer ->
            writer.rustBlock(
                "pub fn $fnName(header_map: &#T::HeaderMap) -> Result<#T, #T::ParseError>",
                RuntimeType.http,
                outputT,
                headerUtil
            ) {
                rust("let headers = header_map.get_all(${binding.locationName.dq()}).iter();")
                deserializeFromHeader(model.expectShape(binding.member.target), binding.member)
            }
        }
    }

    fun generateDeserializePrefixHeaderFn(binding: HttpBinding): RuntimeType {
        check(binding.location == HttpBinding.Location.PREFIX_HEADERS)
        val outputT = symbolProvider.toSymbol(binding.member)
        val fnName = "parse_from_header_${binding.memberName.toSnakeCase()}"
        check(outputT.rustType().stripOuter<RustType.Option>() is RustType.HashMap) { outputT.rustType() }
        val target = model.expectShape(binding.member.target)
        check(target is MapShape)
        val fnName = "deser_prefix_header_${operationShape.id.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}"
        val inner = RuntimeType.forInlineFun("${fnName}_inner", "http_serde_inner") {
            it.rustBlock(
                "pub fn ${fnName}_inner(headers: #T::header::ValueIter<http::HeaderValue>) -> Result<Option<#T>, #T::ParseError>",
                RuntimeType.http,
                symbolProvider.toSymbol(model.expectShape(target.value.target)),
                headerUtil
            ) {
                deserializeFromHeader(model.expectShape(target.value.target), binding.member)
            }
        }
        return RuntimeType.forInlineFun(fnName, "http_serde") { writer ->
            writer.rustBlock(
            "fn $fnName(headers: &#T::HeaderMap) -> Result<#T, #T::ParseError>",
                "pub fn $fnName(header_map: &#T::HeaderMap) -> Result<#T, #T::ParseError>",
                RuntimeType.http,
                outputT,
                headerUtil
            ) {
            deserializeFromHeader(binding)
                rust(
                    """
                    let headers = #T::headers_for_prefix(&header_map, ${binding.locationName.dq()});
                    let out: Result<_, _> = headers.map(|(key, header_name)| {
                        let values = header_map.get_all(header_name);
                        #T(values.iter()).map(|v| (key.to_string(), v.unwrap()))
                    }).collect();
                    out.map(Some)
                """,
                    headerUtil, inner
                )
            }
        }
        return fnName
    }

    /**
@@ -74,19 +119,24 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
    fun generateDeserializePayloadFn(
        binding: HttpBinding,
        errorT: RuntimeType,
        rustWriter: RustWriter,
        // Deserialize a single structure or union member marked as a payload
        structuredHandler: RustWriter.(String) -> Unit,
        // Deserialize a document type marked as a payload
        docHandler: RustWriter.(String) -> Unit
    ): String {
    ): RuntimeType {
        check(binding.location == HttpBinding.Location.PAYLOAD)
        val outputT = symbolProvider.toSymbol(binding.member)
        val fnName = "parse_from_payload_${binding.memberName.toSnakeCase()}"
        rustWriter.rustBlock("fn $fnName(body: &[u8]) -> Result<#T, #T>", outputT, errorT) {
            deserializePayloadBody(binding, errorT, structuredHandler = structuredHandler, docShapeHandler = docHandler)
        val fnName = "deser_payload_${operationShape.id.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}"
        return RuntimeType.forInlineFun(fnName, "http_serde") { rustWriter ->
            rustWriter.rustBlock("pub fn $fnName(body: &[u8]) -> Result<#T, #T>", outputT, errorT) {
                deserializePayloadBody(
                    binding,
                    errorT,
                    structuredHandler = structuredHandler,
                    docShapeHandler = docHandler
                )
            }
        }
        return fnName
    }

    private fun RustWriter.deserializePayloadBody(
@@ -129,9 +179,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
    /** Parse a value from a header
     * This function produces an expression which produces the precise output type required by the output shape
     */
    private fun RustWriter.deserializeFromHeader(binding: HttpBinding) {
        check(binding.location == HttpBinding.Location.HEADER)
        val targetType = model.expectShape(binding.member.target)
    private fun RustWriter.deserializeFromHeader(targetType: Shape, memberShape: MemberShape) {
        val rustType = symbolProvider.toSymbol(targetType).rustType().stripOuter<RustType.Option>()
        val (coreType, coreShape) = if (targetType is CollectionShape) {
            rustType.stripOuter<RustType.Container>() to model.expectShape(targetType.member.target)
@@ -142,19 +190,19 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
        if (coreType == instant) {
            val timestampFormat =
                index.determineTimestampFormat(
                    binding.member,
                    memberShape,
                    HttpBinding.Location.HEADER,
                    defaultTimestampFormat
                )
            val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
            rust(
                "let $parsedValue: Vec<${coreType.render(true)}> = #T::many_dates(&headers, ${binding.locationName.dq()}, #T)?;",
                "let $parsedValue: Vec<${coreType.render(true)}> = #T::many_dates(headers, #T)?;",
                headerUtil,
                timestampFormatType
            )
        } else {
            rust(
                "let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(&headers, ${binding.locationName.dq()})?;",
                "let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(headers)?;",
                headerUtil
            )
            if (coreShape.hasTrait(MediaTypeTrait::class.java)) {
Loading