Unverified Commit b67e651f authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for document deserialization (#273)

parent 16b71cf9
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -355,9 +355,6 @@ class HttpProtocolTestGenerator(
        val AwsJson11 = "aws.protocoltests.json#JsonProtocol"
        val RestJson = "aws.protocoltests.restjson#RestJson"
        private val ExpectFail = setOf(
            // Document deserialization:
            FailingTest(AwsJson11, "PutAndGetInlineDocumentsInput", Action.Response),

            // Endpoint trait https://github.com/awslabs/smithy-rs/issues/197
            // This will also require running operations through the endpoint middleware (or moving endpoint middleware
            // into operation construction
+2 −6
Original line number Diff line number Diff line
@@ -126,14 +126,10 @@ class CustomSerializerGenerator(
        val fnName = serializerName(rustType, memberShape, "deser")
        return RuntimeType.forInlineFun(fnName, "serde_util") { writer ->
            deserializeFn(writer, fnName, symbol) {
                if (rustType.contains(document)) {
                    write("todo!()")
                } else {
                deserializer(rustType, memberShape)
            }
        }
    }
    }

    private fun rollSer(t: RustType, memberShape: MemberShape): Writable {
        return when (t) {
@@ -272,7 +268,7 @@ class CustomSerializerGenerator(
            document -> writable {
                when (serdeDirection) {
                    SerdeDirection.Serialize -> write("#T::SerDoc", RuntimeType.DocJson)
                    SerdeDirection.Deserialize -> TODO("doc deserialization is not supported")
                    SerdeDirection.Deserialize -> write("#T::DeserDoc", RuntimeType.DocJson)
                }
            }
            is RustType.Container -> writable { serdeContainerType(realType, memberShape) }
+108 −1
Original line number Diff line number Diff line
@@ -3,9 +3,13 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

use serde::{Serialize, Serializer};
use serde::{Serialize, Serializer, Deserialize, Deserializer};
use serde_json::Value;
use smithy_types::{Document, Number};
use serde::de::{Visitor, Error, SeqAccess, MapAccess};
use std::fmt::Formatter;
use std::fmt;
use std::collections::HashMap;

#[allow(unused)]
pub fn json_to_doc(json: Value) -> Document {
@@ -22,6 +26,7 @@ pub fn json_to_doc(json: Value) -> Document {
}

pub struct SerDoc<'a>(pub &'a Document);
pub struct DeserDoc(pub Document);

impl Serialize for SerDoc<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
@@ -44,6 +49,108 @@ impl Serialize for SerDoc<'_> {
    }
}

struct DocVisitor;
impl<'de> Visitor<'de> for DocVisitor {
    type Value = Document;

    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        write!(formatter, "Expecting a JSON-like document")
    }

    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Bool(v))
    }

    fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(serde_num_to_num(&serde_json::Number::from(v))))
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(Number::Float(v as _)))
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::Number(Number::Float(v as _)))
    }

    fn visit_char<E>(self, v: char) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::String(v.to_string()))
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> where
        E: Error, {
        Ok(Document::String(v.to_string()))

    }

    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, <A as SeqAccess<'de>>::Error> where
        A: SeqAccess<'de>, {
        let mut out: Vec<Document> = vec![];
        while let Some(next) = seq.next_element::<DeserDoc>()? {
            out.push(next.0);
        }
        Ok(Document::Array(out))
    }

    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, <A as MapAccess<'de>>::Error> where
        A: MapAccess<'de>, {
        let mut out: HashMap<String, Document> = HashMap::new();
        while let Some((k, v)) = map.next_entry::<String, DeserDoc>()? {
            out.insert(k, v.0);
        }
        Ok(Document::Object(out))
    }
}

impl<'de> Deserialize<'de> for DeserDoc {
    fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
        where
            D: Deserializer<'de>
    {
        Ok(DeserDoc(deserializer.deserialize_any(DocVisitor)?))
    }
}

fn serde_num_to_num(number: &serde_json::Number) -> smithy_types::Number {
    if number.is_f64() {
        smithy_types::Number::Float(number.as_f64().unwrap())