Unverified Commit 6011a0fd authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Upgrade Smithy to 1.4.0 & fix Instantiator (#48)

Upgrade Smithy to 1.4.0. This enables us to re-enable the invalid protocol tests for 1.3.0. Along the way, `Instantiator` was refactored to cleanup its handling of `Optional` shapes.
parent e1714596
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -21,6 +21,26 @@ import software.amazon.smithy.utils.CodeWriter
import java.util.function.BiFunction

fun <T : CodeWriter> T.withBlock(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    block: T.() -> Unit
): T {
    return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block)
}

/**
 * Write a block to the writer.
 * If [conditional] is true, the [textBeforeNewLine], followed by [block], followed by [textAfterNewLine]
 * If [conditional] is false, only [block] is written.
 * This enables conditionally wrapping a block in a prefix/suffix, eg.
 *
 * ```
 * writer.withBlock("Some(", ")", conditional = symbol.isOptional()) {
 *      write("symbolValue")
 * }
 * ```
 */
fun <T : CodeWriter> T.conditionalBlock(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    conditional: Boolean = true,
+2 −6
Original line number Diff line number Diff line
@@ -27,11 +27,7 @@ class HttpProtocolTestGenerator(
    // TODO: remove these once Smithy publishes fixes.
    // These tests are not even attempted to be compiled
    val DisableTests = setOf(
        "RestJsonListsSerializeNull",
        "AwsJson11MapsSerializeNullValues",
        "AwsJson11ListsSerializeNull",
        "RestJsonSerializesNullMapValues",
        // This test is flake because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
        // This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37
        "AwsJson11Enums"
    )

@@ -104,7 +100,7 @@ class HttpProtocolTestGenerator(
        }
        testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") {
            writeInline("let input =")
            instantiator.render(httpRequestTestCase.params, inputShape, this)
            instantiator.render(this, inputShape, httpRequestTestCase.params)
            write(";")
            write("let http_request = input.build_http_request().body(()).unwrap();")
            with(httpRequestTestCase) {
+35 −17
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.node.StringNode
@@ -23,6 +24,7 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.conditionalBlock
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
@@ -42,7 +44,7 @@ class Instantiator(
    private val runtimeConfig: RuntimeConfig
) {

    fun render(arg: Node, shape: Shape, writer: RustWriter) {
    fun render(writer: RustWriter, shape: Shape, arg: Node) {
        when (shape) {
            // Compound Shapes
            is StructureShape -> renderStructure(writer, shape, arg as ObjectNode)
@@ -53,6 +55,9 @@ class Instantiator(
            is MapShape -> renderMap(writer, shape, arg as ObjectNode)
            is SetShape -> renderSet(writer, shape, arg as ArrayNode)

            // Members, supporting potentially optional members
            is MemberShape -> renderMember(writer, shape, arg)

            // Wrapped Shapes
            is TimestampShape -> writer.write(
                "\$T::from_epoch_seconds(${(arg as NumberNode).value})",
@@ -77,15 +82,37 @@ class Instantiator(
        }
    }

    /**
     * If the shape is optional: `Some(inner)` or `None`
     * otherwise: `inner`
     */
    private fun renderMember(
        writer: RustWriter,
        shape: MemberShape,
        arg: Node
    ) {
        val target = model.expectShape(shape.target)
        val symbol = symbolProvider.toSymbol(shape)
        if (arg is NullNode) {
            check(
                symbol.isOptional()
            ) { "A null node was provided for $shape but the symbol was not optional. This is invalid input data." }
            writer.write("None")
        } else {
            writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) {
                render(this, target, arg)
            }
        }
    }

    private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode) {
        if (symbolProvider.toSymbol(shape).rustType() is RustType.HashSet) {
            if (!data.isEmpty) {
                writer.rustBlock("") {
                    write("let mut ret = \$T::new();", RuntimeType.HashSet)
                    val valueShape = shape.member.let { model.expectShape(it.target) }
                    data.forEach { v ->
                        withBlock("ret.insert(", ");") {
                            render(v, valueShape, this)
                            renderMember(this, shape.member, v)
                        }
                    }
                    write("ret")
@@ -115,10 +142,9 @@ class Instantiator(
        if (data.members.isNotEmpty()) {
            writer.rustBlock("") {
                write("let mut ret = \$T::new();", RuntimeType.HashMap)
                val valueShape = shape.value.let { model.expectShape(it.target) }
                data.members.forEach { (k, v) ->
                    withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") {
                        render(v, valueShape, this)
                        renderMember(this, shape.value, v)
                    }
                }
                write("ret")
@@ -148,7 +174,7 @@ class Instantiator(
        writer.write("\$T::${memberName.toPascalCase()}", unionSymbol)
        // unions should specify exactly one member
        writer.withBlock("(", ")") {
            render(variant.value, member, this)
            render(this, member, variant.value)
        }
    }

@@ -162,17 +188,9 @@ class Instantiator(
        shape: CollectionShape,
        data: ArrayNode
    ) {
        val member = model.expectShape(shape.member.target)
        val memberSymbol = symbolProvider.toSymbol(shape.member)
        writer.withBlock("vec![", "]") {
            data.elements.forEach {
                if (it.isNullNode) {
                    write("None")
                } else {
                    withBlock("Some(", ")", conditional = memberSymbol.isOptional()) {
                        render(it, member, this)
                    }
                }
            data.elements.forEach { v ->
                renderMember(this, shape.member, v)
                write(",")
            }
        }
@@ -209,7 +227,7 @@ class Instantiator(
            val func = symbolProvider.toMemberName(memberShape)
            if (!value.isNullNode) {
                writer.withBlock(".$func(", ")") {
                    render(value, targetShape, this)
                    render(this, targetShape, value)
                }
            }
        }
+2 −2
Original line number Diff line number Diff line
@@ -12,9 +12,9 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.conditionalBlock
import software.amazon.smithy.rust.codegen.lang.render
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
@@ -158,7 +158,7 @@ class StructureGenerator(
            }

            rustBlock("pub fn build(self) -> $returnType", structureSymbol) {
                withBlock("Ok(", ")", conditional = fallibleBuilder) {
                conditionalBlock("Ok(", ")", conditional = fallibleBuilder) {
                    rustBlock("\$T", structureSymbol) {
                        members.forEach { member ->
                            val memberName = symbolProvider.toMemberName(member)
+11 −7
Original line number Diff line number Diff line
@@ -80,7 +80,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            writer.withBlock("let result = ", ";") {
                sut.render(data, union, this)
                sut.render(this, union, data)
            }
            writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));")
        }
@@ -103,7 +103,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            writer.withBlock("let result = ", ";") {
                sut.render(data, structure, this)
                sut.render(this, structure, data)
            }
            writer.write("assert_eq!(result.bar, 10);")
            writer.write("assert_eq!(result.foo.unwrap(), \"hello\");")
@@ -129,7 +129,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            withBlock("let result = ", ";") {
                sut.render(data, structure, this)
                sut.render(this, structure, data)
            }
            rust(
                """
@@ -160,7 +160,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            writer.withBlock("let result = ", ";") {
                sut.render(data, model.lookup("com.test#MyList"), writer)
                sut.render(writer, model.lookup("com.test#MyList"), data)
            }
            writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""")
        }
@@ -182,7 +182,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            writer.withBlock("let result = ", ";") {
                sut.render(data, model.lookup("com.test#MySparseList"), writer)
                sut.render(writer, model.lookup("com.test#MySparseList"), data)
            }
            writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""")
        }
@@ -206,7 +206,7 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            writer.withBlock("let result = ", ";") {
                sut.render(data, model.lookup("com.test#NestedMap"), writer)
                sut.render(writer, model.lookup("com.test#NestedMap"), data)
            }
            writer.write(
                """
@@ -229,7 +229,11 @@ class InstantiatorTest {
        writer.write("#[test]")
        writer.rustBlock("fn test_blob()") {
            withBlock("let blob = ", ";") {
                sut.render(StringNode.parse("foo".dq()), BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), this)
                sut.render(
                    this,
                    BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(),
                    StringNode.parse("foo".dq())
                )
            }
            write("assert_eq!(std::str::from_utf8(blob.as_ref()).unwrap(), \"foo\");")
        }
Loading