Commit 275f997d authored by Fahad Zubair's avatar Fahad Zubair
Browse files

Fix map comparison: implement order-insensitive equality for `ciborium::Value::Map`

parent 20ff1c7b
Loading
Loading
Loading
Loading
+91 −7
Original line number Diff line number Diff line
@@ -415,6 +415,68 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure>
    }
}

/// Compares two `ciborium::Value` instances for semantic equality.
///
/// This function recursively compares two CBOR values, correctly handling arrays and maps
/// according to the CBOR specification. Arrays are compared element-wise in order,
/// while maps are compared without considering the order of key-value pairs.
fn cbor_values_equal(a: &ciborium::Value, b: &ciborium::Value) -> Result<bool, ProtocolTestFailure> {
    let result = match (a, b) {
        (ciborium::Value::Array(a_array), ciborium::Value::Array(b_array)) => {
            // Both arrays should be equal in size.
            a_array.len() == b_array.len() &&
            // Compare arrays element-wise.
            a_array.iter().zip(b_array.iter()).try_fold(true, |acc, (a_elem, b_elem)| {
                cbor_values_equal(a_elem, b_elem).map(|equal| acc && equal)
            })?
        },

        // Convert `ciborium::Map` to sorted `BTreeMap` and then compare the sorted maps.
        (ciborium::Value::Map(a_map), ciborium::Value::Map(b_map)) => {
            let a_btree = ciborium_map_to_btreemap(a_map)?;
            let b_btree = ciborium_map_to_btreemap(b_map)?;

            if a_btree.len() != b_btree.len() {
                false
            }
            else {
                // Each key in `a` should exist in `b`, and the values should match.
                a_btree.iter().try_fold(true, |acc, (a_key, a_value)|
                    b_btree.get(a_key)
                        .map(|b_value| cbor_values_equal(a_value, b_value).map(|equal| acc && equal))
                        .unwrap_or(Ok(false))
                )?
            }
        },

        _ => a == b,
    };

    Ok(result)
}

/// Transforms a `ciborium::Value::Map` into a `BTreeMap<&String, &ciborium::Value>` sorted by keys.
///
/// CBOR maps (`Value::Map`) are internally represented as vectors of key-value pairs,
/// and their direct comparison is affected by the order of these pairs.
/// Since CBOR specification treats maps as unordered collections,
/// this function converts the vector into a `BTreeMap`, which maintains
/// the entries in a sorted order based on the keys.
/// This allows for consistent, order-independent comparisons between maps.
fn ciborium_map_to_btreemap(cbor_map: &[(ciborium::Value, ciborium::Value)]) -> Result<std::collections::BTreeMap<&String, &ciborium::Value>, ProtocolTestFailure> {
    let mut btree = std::collections::BTreeMap::new();
    for (key, value) in cbor_map {
        match key {
            ciborium::Value::Text(key_str) => btree.insert(key_str, value),
            _ => return Err(ProtocolTestFailure::InvalidBodyFormat {
                expected: "a text key as map entry".to_string(),
                found: format!("{:?}", key),
            })
        };
    }
    Ok(btree)
}

fn try_cbor_eq<T: AsRef<[u8]> + Debug>(
    actual_body: T,
    expected_body: &str,
@@ -422,16 +484,16 @@ fn try_cbor_eq<T: AsRef<[u8]> + Debug>(
    let decoded = base64_simd::STANDARD
        .decode_to_vec(expected_body)
        .expect("smithy protocol test `body` property is not properly base64 encoded");
    let expected_cbor_value: ciborium::value::Value =
    let expected_cbor_value: ciborium::Value =
        ciborium::from_reader(decoded.as_slice()).expect("expected value must be valid CBOR");
    let actual_cbor_value: ciborium::value::Value = ciborium::from_reader(actual_body.as_ref())
    let actual_cbor_value: ciborium::Value = ciborium::from_reader(actual_body.as_ref())
        .map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
            expected: "cbor".to_owned(),
            found: format!("{} {:?}", e, actual_body),
        })?;
    let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body);

    if expected_cbor_value != actual_cbor_value {
    if !cbor_values_equal(&expected_cbor_value, &actual_cbor_value)? {
        let expected_body_annotated_hex: String = cbor_diag::parse_bytes(&decoded)
            .expect("smithy protocol test `body` property is not valid CBOR")
            .to_hex();
@@ -601,14 +663,36 @@ mod tests {

    #[test]
    fn test_validate_cbor_body() {
        let base64_encode = |v : &[u8]| base64_simd::STANDARD.encode_to_string(v);

        // The following is the CBOR representation of `{"abc": 5 }`.
        let actual = [0xbf, 0x63, 0x61, 0x62, 0x63, 0x05, 0xff];
        // The following is the CBOR representation of `{"abc": 5 }` using a definite length map.
        let expected = [0xA1, 0x63, 0x61, 0x62, 0x63, 0x05];
        let expected_base64 = base64_simd::STANDARD.encode_to_string(expected);
        // The following is the base64-encoded CBOR representation of `{"abc": 5 }` using a definite length map.
        let expected_base64 = base64_encode(&[0xA1, 0x63, 0x61, 0x62, 0x63, 0x05]);

        validate_body(actual, expected_base64.as_str(), MediaType::Cbor)
            .expect("unexpected mismatch between CBOR definite and indefinite map encodings");

        // The following is the CBOR representation of `{"a":1, "b":2}`.
        let actual = [0xBF, 0x61, 0x61, 0x01, 0x61, 0x62, 0x02, 0xFF];
        // The following is the base64-encoded CBOR representation of `{"b":2, "a":1}`.
        let expected_base64 = base64_encode(&[0xBF, 0x61, 0x62, 0x02, 0x61, 0x61, 0x01, 0xFF]);
        validate_body(actual, expected_base64.as_str(), MediaType::Cbor)
            .expect("different ordering in CBOR decoded maps do not match");

        // The following is the CBOR representation of `{"a":[1,2,{"b":3, "c":4}]}`.
        let actual = [0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x62, 0x03, 0x61, 0x63, 0x04, 0xFF, 0xFF, 0xFF];
        // The following is the base64-encoded CBOR representation of `{"a":[1,2,{"c":4, "b":3}]}`.
        let expected_base64 = base64_encode(&[0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x63, 0x04, 0x61, 0x62, 0x03, 0xFF, 0xFF, 0xFF]);
        validate_body(actual, expected_base64.as_str(), MediaType::Cbor)
            .expect("different ordering in CBOR decoded maps do not match");

        // The following is the CBOR representation of `{"a":[1,2]}`.
        let actual = [0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xFF, 0xFF];
        // The following is the CBOR representation of `{"a":[2,1]}`.
        let expected_base64 = base64_encode(&[0xBF, 0x61, 0x61, 0x9F, 0x02, 0x01, 0xFF, 0xFF]);
        validate_body(actual, expected_base64.as_str(), MediaType::Cbor)
            .expect("expected base64-encoded CBOR value did not match");
            .expect_err("arrays in CBOR should follow strict ordering");
    }

    #[test]