Unverified Commit 409b6b99 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Assert field-by-field equality in protocol tests (#337)

Protocol tests currently use `PartialEq`, but when we add streaming members, we can no longer compare equality synchronously & without
side effects. This diff removes `PartialEq` from output shapes (we may re add it later for non-streaming shapes) & instead asserts equality of each output member in turn.
parent 1d00f381
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -157,6 +157,9 @@ data class RustMetadata(
    fun withDerives(vararg newDerive: RuntimeType): RustMetadata =
        this.copy(derives = derives.copy(derives = derives.derives + newDerive))

    fun withoutDerives(vararg withoutDerives: RuntimeType) =
        this.copy(derives = derives.copy(derives = derives.derives - withoutDerives))

    private fun attributes(): List<Attribute> = additionalAttributes + derives

    fun renderAttributes(writer: RustWriter): RustMetadata {
@@ -231,7 +234,11 @@ sealed class Attribute {
     * Finally, any symbols listed will be imported when this attribute is rendered. This enables using attributes like
     * `#[serde(Serialize)]` where `Serialize` is actually a symbol that must be imported.
     */
    data class Custom(val annotation: String, val symbols: List<RuntimeType> = listOf(), val container: Boolean = false) : Attribute() {
    data class Custom(
        val annotation: String,
        val symbols: List<RuntimeType> = listOf(),
        val container: Boolean = false
    ) : Attribute() {
        override fun render(writer: RustWriter) {

            val bang = if (container) "!" else ""
+5 −1
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@ import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhaustive
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.smithy.RuntimeType.Companion.PartialEq
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait

/**
 * Default delegator to enable easily decorating another symbol provider.
@@ -72,7 +74,9 @@ class BaseSymbolMetadataProvider(base: RustSymbolProvider) : SymbolMetadataProvi
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        return containerDefault
        return containerDefault.letIf(structureShape.hasTrait(SyntheticOutputTrait::class.java)) {
            it.withoutDerives(PartialEq)
        }
    }

    override fun unionMeta(unionShape: UnionShape): RustMetadata {
+7 −3
Original line number Diff line number Diff line
@@ -261,13 +261,17 @@ class HttpProtocolTestGenerator(
            val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name
            rust("""let parsed = parsed.expect_err("should be error response");""")
            rustBlock("if let #TKind::$errorVariant(actual_error) = parsed.kind", errorSymbol) {
                write("assert_eq!(expected_output, actual_error);")
                rust("assert_eq!(expected_output, actual_error);")
            }
            rustBlock("else") {
                write("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
                rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
            }
        } else {
            write("assert_eq!(parsed.unwrap(), expected_output);")
            rust("let parsed = parsed.unwrap();")
            outputShape.members().forEach { member ->
                val memberName = protocolConfig.symbolProvider.toMemberName(member)
                rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""")
            }
        }
    }