Commit a2679acf authored by Fahad Zubair's avatar Fahad Zubair
Browse files

Specifically, use float.is_nan() when comparing floats that may have NaN values.

parent 275f997d
Loading
Loading
Loading
Loading
+42 −21
Original line number Diff line number Diff line
@@ -420,7 +420,10 @@ fn try_json_eq(expected: &str, actual: &str) -> Result<(), ProtocolTestFailure>
/// This function recursively compares two CBOR values, correctly handling arrays and maps
/// according to the CBOR specification. Arrays are compared element-wise in order,
/// while maps are compared without considering the order of key-value pairs.
fn cbor_values_equal(a: &ciborium::Value, b: &ciborium::Value) -> Result<bool, ProtocolTestFailure> {
fn cbor_values_equal(
    a: &ciborium::Value,
    b: &ciborium::Value,
) -> Result<bool, ProtocolTestFailure> {
    let result = match (a, b) {
        (ciborium::Value::Array(a_array), ciborium::Value::Array(b_array)) => {
            // Both arrays should be equal in size.
@@ -429,7 +432,7 @@ fn cbor_values_equal(a: &ciborium::Value, b: &ciborium::Value) -> Result<bool, P
            a_array.iter().zip(b_array.iter()).try_fold(true, |acc, (a_elem, b_elem)| {
                cbor_values_equal(a_elem, b_elem).map(|equal| acc && equal)
            })?
        },
        }

        // Convert `ciborium::Map` to sorted `BTreeMap` and then compare the sorted maps.
        (ciborium::Value::Map(a_map), ciborium::Value::Map(b_map)) => {
@@ -438,16 +441,22 @@ fn cbor_values_equal(a: &ciborium::Value, b: &ciborium::Value) -> Result<bool, P

            if a_btree.len() != b_btree.len() {
                false
            }
            else {
            } else {
                // Each key in `a` should exist in `b`, and the values should match.
                a_btree.iter().try_fold(true, |acc, (a_key, a_value)|
                    b_btree.get(a_key)
                        .map(|b_value| cbor_values_equal(a_value, b_value).map(|equal| acc && equal))
                a_btree.iter().try_fold(true, |acc, (a_key, a_value)| {
                    b_btree
                        .get(a_key)
                        .map(|b_value| {
                            cbor_values_equal(a_value, b_value).map(|equal| acc && equal)
                        })
                        .unwrap_or(Ok(false))
                )?
                })?
            }
        }

        (ciborium::Value::Float(a_float), ciborium::Value::Float(b_float)) => {
            a_float == b_float || a_float.is_nan() && b_float.is_nan()
        }
        },

        _ => a == b,
    };
@@ -463,15 +472,19 @@ fn cbor_values_equal(a: &ciborium::Value, b: &ciborium::Value) -> Result<bool, P
/// this function converts the vector into a `BTreeMap`, which maintains
/// the entries in a sorted order based on the keys.
/// This allows for consistent, order-independent comparisons between maps.
fn ciborium_map_to_btreemap(cbor_map: &[(ciborium::Value, ciborium::Value)]) -> Result<std::collections::BTreeMap<&String, &ciborium::Value>, ProtocolTestFailure> {
fn ciborium_map_to_btreemap(
    cbor_map: &[(ciborium::Value, ciborium::Value)],
) -> Result<std::collections::BTreeMap<&String, &ciborium::Value>, ProtocolTestFailure> {
    let mut btree = std::collections::BTreeMap::new();
    for (key, value) in cbor_map {
        match key {
            ciborium::Value::Text(key_str) => btree.insert(key_str, value),
            _ => return Err(ProtocolTestFailure::InvalidBodyFormat {
            _ => {
                return Err(ProtocolTestFailure::InvalidBodyFormat {
                    expected: "a text key as map entry".to_string(),
                    found: format!("{:?}", key),
                })
            }
        };
    }
    Ok(btree)
@@ -486,10 +499,12 @@ fn try_cbor_eq<T: AsRef<[u8]> + Debug>(
        .expect("smithy protocol test `body` property is not properly base64 encoded");
    let expected_cbor_value: ciborium::Value =
        ciborium::from_reader(decoded.as_slice()).expect("expected value must be valid CBOR");
    let actual_cbor_value: ciborium::Value = ciborium::from_reader(actual_body.as_ref())
        .map_err(|e| ProtocolTestFailure::InvalidBodyFormat {
    let actual_cbor_value: ciborium::Value =
        ciborium::from_reader(actual_body.as_ref()).map_err(|e| {
            ProtocolTestFailure::InvalidBodyFormat {
                expected: "cbor".to_owned(),
                found: format!("{} {:?}", e, actual_body),
            }
        })?;
    let actual_body_base64 = base64_simd::STANDARD.encode_to_string(&actual_body);

@@ -681,9 +696,15 @@ mod tests {
            .expect("different ordering in CBOR decoded maps do not match");

        // The following is the CBOR representation of `{"a":[1,2,{"b":3, "c":4}]}`.
        let actual = [0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x62, 0x03, 0x61, 0x63, 0x04, 0xFF, 0xFF, 0xFF];
        let actual = [
            0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x62, 0x03, 0x61, 0x63, 0x04, 0xFF,
            0xFF, 0xFF,
        ];
        // The following is the base64-encoded CBOR representation of `{"a":[1,2,{"c":4, "b":3}]}`.
        let expected_base64 = base64_encode(&[0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x63, 0x04, 0x61, 0x62, 0x03, 0xFF, 0xFF, 0xFF]);
        let expected_base64 = base64_encode(&[
            0xBF, 0x61, 0x61, 0x9F, 0x01, 0x02, 0xBF, 0x61, 0x63, 0x04, 0x61, 0x62, 0x03, 0xFF,
            0xFF, 0xFF,
        ]);
        validate_body(actual, expected_base64.as_str(), MediaType::Cbor)
            .expect("different ordering in CBOR decoded maps do not match");