Unverified Commit 12cf9160 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Update serde implementation to support out of range floats (#3825)

## Motivation and Context
Fix serde behavior to match generated code. This is important to avoid
loosing data during serialization, especially as out-of-range floats
often indicate an error.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent d8fbf47d
Loading
Loading
Loading
Loading
+35 −1
Original line number Diff line number Diff line
@@ -10,6 +10,8 @@ import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
@@ -209,7 +211,10 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
     */
    private fun serializeNumber(shape: NumberShape): RuntimeType {
        val numericType = SimpleShapes.getValue(shape::class)
        return RuntimeType.forInlineFun(
        return when (shape) {
            is FloatShape, is DoubleShape -> serializeFloat(shape)
            else ->
                RuntimeType.forInlineFun(
                    numericType.toString(),
                    PrimitiveShapesModule,
                ) {
@@ -218,6 +223,35 @@ class SerializeImplGenerator(private val codegenContext: CodegenContext) {
                    }
                }
        }
    }

    private fun serializeFloat(shape: NumberShape): RuntimeType {
        val numericType = SimpleShapes.getValue(shape::class)
        return RuntimeType.forInlineFun(
            numericType.toString(),
            PrimitiveShapesModule,
        ) {
            implSerializeConfigured(symbolBuilder(shape, numericType).build()) {
                rustTemplate(
                    """
                    if !self.settings.out_of_range_floats_as_strings {
                        return self.value.serialize(serializer)
                    }
                    if self.value.is_nan() {
                        serializer.serialize_str("NaN")
                    } else if *self.value == #{ty}::INFINITY {
                        serializer.serialize_str("Infinity")
                    } else if *self.value == #{ty}::NEG_INFINITY {
                        serializer.serialize_str("-Infinity")
                    } else {
                        self.value.serialize(serializer)
                    }
                    """,
                    "ty" to numericType,
                )
            }
        }
    }

    /**
     * Serialize a type that already implements `Serialize` directly via `value.serialize(serializer)`
+10 −5
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ object SupportStructures {
                {
                    use #{serde}::Serialize;
                    value
                        .serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: true })
                        .serialize_ref(&#{SerializationSettings}::redact_sensitive_fields())
                        .serialize(serializer)
                }
                """,
@@ -70,7 +70,7 @@ object SupportStructures {
                {
                    use #{serde}::Serialize;
                    value
                        .serialize_ref(&#{SerializationSettings} { redact_sensitive_fields: false })
                        .serialize_ref(&#{SerializationSettings}::leak_sensitive_fields())
                        .serialize(serializer)
                }
                """,
@@ -211,7 +211,6 @@ object SupportStructures {

    private fun serializationSettings() =
        RuntimeType.forInlineFun("SerializationSettings", supportModule) {
            // TODO(serde): Consider removing `derive(Default)`
            rustTemplate(
                """
                /// Settings for use when serializing structures
@@ -220,6 +219,12 @@ object SupportStructures {
                pub struct SerializationSettings {
                    /// Replace all sensitive fields with `<redacted>` during serialization
                    pub redact_sensitive_fields: bool,

                    /// Serialize Nan, infinity and negative infinity as strings.
                    ///
                    /// For protocols like JSON, this avoids the loss-of-information that occurs when these out-of-range values
                    /// are serialized as null.
                    pub out_of_range_floats_as_strings: bool,
                }

                impl SerializationSettings {
@@ -227,10 +232,10 @@ object SupportStructures {
                    ///
                    /// Note: This may alter the type of the serialized output and make it impossible to deserialize as
                    /// numerical fields will be replaced with strings.
                    pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true } }
                    pub const fn redact_sensitive_fields() -> Self { Self { redact_sensitive_fields: true, out_of_range_floats_as_strings: false } }

                    /// Preserve the contents of sensitive fields during serializing
                    pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false } }
                    pub const fn leak_sensitive_fields() -> Self { Self { redact_sensitive_fields: false, out_of_range_floats_as_strings: false } }
                }
                """,
            )
+27 −3
Original line number Diff line number Diff line
@@ -70,7 +70,9 @@ class SerdeDecoratorTest {
           blob: SensitiveBlob,
           constrained: Constrained,
           recursive: Recursive,
           map: EnumKeyedMap
           map: EnumKeyedMap,
           float: Float,
           double: Double
        }

        structure Constrained {
@@ -134,6 +136,8 @@ class SerdeDecoratorTest {
        structure Nested {
          @required
          int: Integer,
          float: Float,
          double: Double,
          sensitive: Timestamps,
          notSensitive: AlsoTimestamps,
          manyEnums: TestEnumList,
@@ -202,8 +206,12 @@ class SerdeDecoratorTest {
                            .e(Some(TestEnum::A))
                            .document(Some(Document::String("hello!".into())))
                            .blob(Some(Blob::new("hello")))
                            .float(Some(f32::INFINITY))
                            .double(Some(f64::NAN))
                            .nested(Some(Nested::builder()
                                .int(5)
                                .float(Some(f32::NEG_INFINITY))
                                .double(Some(f64::NEG_INFINITY))
                                .sensitive(Some(sensitive_map.clone()))
                                .not_sensitive(Some(sensitive_map))
                                .many_enums(Some(vec![TestEnum::A]))
@@ -274,6 +282,8 @@ class SerdeDecoratorTest {
        "e": "A",
        "nested": {
          "int": 5,
          "float": "-Infinity",
          "double": "-Infinity",
          "sensitive": {
            "a": "1970-01-01T00:00:00Z"
          },
@@ -289,7 +299,9 @@ class SerdeDecoratorTest {
          "enum": "B"
        },
        "document": "hello!",
        "blob": "aGVsbG8="
        "blob": "aGVsbG8=",
        "float": "Infinity",
        "double": "NaN"
    }""".replace("\\s".toRegex(), "")

    private val expectedRedacted =
@@ -298,6 +310,8 @@ class SerdeDecoratorTest {
        "e": "<redacted>",
        "nested": {
          "int": 5,
          "float": "-Infinity",
          "double": "-Infinity",
          "sensitive": {
            "a": "<redacted>"
          },
@@ -311,7 +325,9 @@ class SerdeDecoratorTest {
        },
        "union": "<redacted>",
        "document": "hello!",
        "blob": "<redacted>"
        "blob": "<redacted>",
        "float": "Infinity",
        "double": "NaN"
        }
        """.replace("\\s".toRegex(), "")

@@ -343,8 +359,12 @@ class SerdeDecoratorTest {
                                .e("A".into())
                                .document(Document::String("hello!".into()))
                                .blob(Blob::new("hello"))
                                .float(f32::INFINITY)
                                .double(f64::NAN)
                                .nested(Nested::builder()
                                    .int(5)
                                    .float(f32::NEG_INFINITY)
                                    .double(f64::NEG_INFINITY)
                                    .sensitive("a", DateTime::from(UNIX_EPOCH))
                                    .not_sensitive("a", DateTime::from(UNIX_EPOCH))
                                    .many_enums("A".into())
@@ -355,11 +375,15 @@ class SerdeDecoratorTest {
                                .build()
                                .unwrap();
                            let mut settings = #{crate}::serde::SerializationSettings::default();
                            settings.out_of_range_floats_as_strings = true;
                            let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
                            assert_eq!(serialized, ${expectedNoRedactions.dq()});
                            settings.redact_sensitive_fields = true;
                            let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
                            assert_eq!(serialized, ${expectedRedacted.dq()});
                            settings.out_of_range_floats_as_strings = false;
                            let serialized = #{serde_json}::to_string(&input.serialize_ref(&settings)).expect("failed to serialize");
                            assert_ne!(serialized, ${expectedRedacted.dq()});
                            """,
                            *codegenScope,
                        )