Unverified Commit f27aa546 authored by Weihang Lo's avatar Weihang Lo Committed by GitHub
Browse files

Test `httpRequestTests` against actual Services (#1708)



* Make Instantiator generate default values for required field on demand

* Move looping over operations into ServerProtocolTestGenerator

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Add protocol test helper functions

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Add method param to construct http request

* Put request validation logic inside closure

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Make protocol test response instantiate with default values

* Add module meta for helper module

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: default avatardavid-perez <d@vidp.dev>

* Address most style suggestions

* add companion object for attribute #[allow(dead_code)]

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Use writable to make code readable

* recursively call `filldefaultValue`

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>

* Exercise with `OperationExtension`

* Temporary protocol tests fix for awslabs/smithy#1391

Missing `X-Amz-Target` in response header

* Add `X-Amz-Target` for common models

Signed-off-by: default avatarWeihang Lo <weihanglo@users.noreply.github.com>
Co-authored-by: default avatardavid-perez <d@vidp.dev>
Co-authored-by: default avatarHarry Barber <hlbarber@amazon.co.uk>
parent e009f3f4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -384,6 +384,7 @@ sealed class Attribute {
         */
        val NonExhaustive = Custom("non_exhaustive")
        val AllowUnusedMut = Custom("allow(unused_mut)")
        val AllowDeadCode = Custom("allow(dead_code)")
        val DocHidden = Custom("doc(hidden)")
        val DocInline = Custom("doc(inline)")
    }
+62 −8
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.Node
@@ -68,13 +69,19 @@ class Instantiator(
        val streaming: Boolean,
        // Whether we are instantiating with a Builder, in which case all setters take Option
        val builder: Boolean,
        // Fill out `required` fields with a default value.
        val defaultsForRequiredFields: Boolean,
    )

    companion object {
        fun defaultContext() = Ctx(lowercaseMapKeys = false, streaming = false, builder = false, defaultsForRequiredFields = false)
    }

    fun render(
        writer: RustWriter,
        shape: Shape,
        arg: Node,
        ctx: Ctx = Ctx(lowercaseMapKeys = false, streaming = false, builder = false),
        ctx: Ctx = defaultContext(),
    ) {
        when (shape) {
            // Compound Shapes
@@ -222,14 +229,23 @@ class Instantiator(
     */
    private fun renderUnion(writer: RustWriter, shape: UnionShape, data: ObjectNode, ctx: Ctx) {
        val unionSymbol = symbolProvider.toSymbol(shape)

        val variant = if (ctx.defaultsForRequiredFields && data.members.isEmpty()) {
            val (name, memberShape) = shape.allMembers.entries.first()
            val targetShape = model.expectShape(memberShape.target)
            Node.from(name) to fillDefaultValue(targetShape)
        } else {
            check(data.members.size == 1)
        val variant = data.members.iterator().next()
        val memberName = variant.key.value
            val entry = data.members.iterator().next()
            entry.key to entry.value
        }

        val memberName = variant.first.value
        val member = shape.expectMember(memberName)
        writer.write("#T::${symbolProvider.toMemberName(member)}", unionSymbol)
        // unions should specify exactly one member
        writer.withBlock("(", ")") {
            renderMember(this, member, variant.value, ctx)
            renderMember(this, member, variant.second, ctx)
        }
    }

@@ -267,16 +283,54 @@ class Instantiator(
     * ```
     */
    private fun renderStructure(writer: RustWriter, shape: StructureShape, data: ObjectNode, ctx: Ctx) {
        writer.write("#T::builder()", symbolProvider.toSymbol(shape))
        data.members.forEach { (key, value) ->
            val memberShape = shape.expectMember(key.value)
        fun renderMemberHelper(memberShape: MemberShape, value: Node) {
            writer.withBlock(".${memberShape.setterName()}(", ")") {
                renderMember(this, memberShape, value, ctx)
            }
        }

        writer.write("#T::builder()", symbolProvider.toSymbol(shape))
        if (ctx.defaultsForRequiredFields) {
            shape.allMembers.entries
                .filter { (name, memberShape) ->
                    memberShape.isRequired && !data.members.containsKey(Node.from(name))
                }
                .forEach { (_, memberShape) ->
                    renderMemberHelper(memberShape, fillDefaultValue(memberShape))
                }
        }

        data.members.forEach { (key, value) ->
            val memberShape = shape.expectMember(key.value)
            renderMemberHelper(memberShape, value)
        }
        writer.write(".build()")
        if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) {
            writer.write(".unwrap()")
        }
    }

    /**
     * Returns a default value for a shape.
     *
     * Warning: this method does not take into account any constraint traits attached to the shape.
     */
    private fun fillDefaultValue(shape: Shape): Node = when (shape) {
        is MemberShape -> fillDefaultValue(model.expectShape(shape.target))

        // Aggregate shapes.
        is StructureShape -> Node.objectNode()
        is UnionShape -> Node.objectNode()
        is CollectionShape -> Node.arrayNode()
        is MapShape -> Node.objectNode()

        // Simple Shapes
        is TimestampShape -> Node.from(0) // Number node for timestamp
        is BlobShape -> Node.from("") // String node for bytes
        is StringShape -> Node.from("")
        is NumberShape -> Node.from(0)
        is BooleanShape -> Node.from(false)
        is DocumentShape -> Node.objectNode()
        else -> throw CodegenException("Unrecognized shape `$shape`")
    }
}
+82 −0
Original line number Diff line number Diff line
@@ -66,6 +66,41 @@ class InstantiatorTest {
            member: WithBox,
            value: Integer
        }

        structure MyStructRequired {
            @required
            str: String,
            @required
            primitiveInt: PrimitiveInteger,
            @required
            int: Integer,
            @required
            ts: Timestamp,
            @required
            byte: Byte
            @required
            union: NestedUnion,
            @required
            structure: NestedStruct,
            @required
            list: MyList,
            @required
            map: NestedMap,
            @required
            doc: Document
        }

        union NestedUnion {
            struct: NestedStruct,
            int: Integer
        }

        structure NestedStruct {
            @required
            str: String,
            @required
            num: Integer
        }
    """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) }

    private val symbolProvider = testSymbolProvider(model)
@@ -236,4 +271,51 @@ class InstantiatorTest {
        }
        writer.compileAndTest()
    }

    @Test
    fun `generate struct with missing required members`() {
        val structure = model.lookup<StructureShape>("com.test#MyStructRequired")
        val inner = model.lookup<StructureShape>("com.test#Inner")
        val nestedStruct = model.lookup<StructureShape>("com.test#NestedStruct")
        val union = model.lookup<UnionShape>("com.test#NestedUnion")
        val sut = Instantiator(symbolProvider, model, runtimeConfig, CodegenTarget.SERVER)
        val data = Node.parse("{}")
        val writer = RustWriter.forModule("model")
        structure.renderWithModelBuilder(model, symbolProvider, writer)
        inner.renderWithModelBuilder(model, symbolProvider, writer)
        nestedStruct.renderWithModelBuilder(model, symbolProvider, writer)
        UnionGenerator(model, symbolProvider, writer, union).render()
        writer.test {
            writer.withBlock("let result = ", ";") {
                sut.render(this, structure, data, Instantiator.defaultContext().copy(defaultsForRequiredFields = true))
            }
            writer.write(
                """
                use std::collections::HashMap;
                use aws_smithy_types::{DateTime, Document};

                let expected = MyStructRequired {
                    str: Some("".into()),
                    primitive_int: 0,
                    int: Some(0),
                    ts: Some(DateTime::from_secs(0)),
                    byte: Some(0),
                    union: Some(NestedUnion::Struct(NestedStruct {
                        str: Some("".into()),
                        num: Some(0),
                    })),
                    structure: Some(NestedStruct {
                        str: Some("".into()),
                        num: Some(0),
                    }),
                    list: Some(vec![]),
                    map: Some(HashMap::new()),
                    doc: Some(Document::Object(HashMap::new())),
                };
                assert_eq!(result, expected);
                """,
            )
        }
        writer.compileAndTest()
    }
}
+8 −2
Original line number Diff line number Diff line
@@ -34,7 +34,10 @@ service Config {
        uri: "/",
        body: "{\"as\": 5, \"async\": true}",
        bodyMediaType: "application/json",
        headers: {"Content-Type": "application/x-amz-json-1.1"}
        headers: {
            "Content-Type": "application/x-amz-json-1.1",
            "X-Amz-Target": "Config.ReservedWordsAsMembers",
        },
    }
])
operation ReservedWordsAsMembers {
@@ -78,7 +81,10 @@ structure Type {
        uri: "/",
        body: "{\"regular_string\": \"hello!\"}",
        bodyMediaType: "application/json",
        headers: {"Content-Type": "application/x-amz-json-1.1"}
        headers: {
            "Content-Type": "application/x-amz-json-1.1",
            "X-Amz-Target": "Config.StructureNamePunning",
        },
    }
])
operation StructureNamePunning {
+5 −8
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.client.rustlang.RustModule
import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.DefaultPublicModules
import software.amazon.smithy.rust.codegen.client.smithy.RustCrate
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport
@@ -37,15 +38,11 @@ open class ServerServiceGenerator(
     * which assigns a symbol location to each shape.
     */
    fun render() {
        for (operation in operations) {
            rustCrate.useShapeWriter(operation) { operationWriter ->
                protocolGenerator.serverRenderOperation(
                    operationWriter,
                    operation,
                )
                ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, operation, operationWriter)
                    .render()
        rustCrate.withModule(DefaultPublicModules["operation"]!!) { writer ->
            ServerProtocolTestGenerator(coreCodegenContext, protocolSupport, protocolGenerator).render(writer)
        }

        for (operation in operations) {
            if (operation.errors.isNotEmpty()) {
                rustCrate.withModule(RustModule.Error) { writer ->
                    renderCombinedErrors(writer, operation)
Loading