From 26cb37afb0087d1dee401fafc907c112233bbaaa Mon Sep 17 00:00:00 2001 From: Burak Date: Tue, 7 Mar 2023 10:04:56 +0000 Subject: [PATCH] Python: Add tests and fix the issues with `Timestamp` and `ByteStream` (#2431) * Add `timestamp` type test and fix the conversion error * Add some tests for `ByteStream` and fix async issues * Use `__anext__` method instead of `anext` --------- Co-authored-by: Matteo Bigoi <1781140+crisidev@users.noreply.github.com> --- .../serialize/JsonSerializerGenerator.kt | 8 +- codegen-server/python/build.gradle.kts | 3 + .../generators/PythonServerTypesTest.kt | 91 +++++++++++ .../aws-smithy-http-server-python/Cargo.toml | 10 +- .../src/pytests/bytestream.rs | 151 ++++++++++++++++++ .../src/pytests/harness.rs | 11 ++ .../src/types.rs | 82 +++++----- 7 files changed, 306 insertions(+), 50 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/pytests/bytestream.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/pytests/harness.rs diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 89dd945d3..c851437dc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -38,7 +38,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section -import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError @@ -179,7 +178,6 @@ class JsonSerializerGenerator( private val serializerUtil = SerializerUtil(model) private val operationSerModule = RustModule.private("operation_ser") private val jsonSerModule = RustModule.private("json_ser") - private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) /** * Reusable structure serializer implementation that can be used to generate serializing code for @@ -407,11 +405,7 @@ class JsonSerializerGenerator( val timestampFormat = httpBindingResolver.timestampFormat(context.shape, HttpLocation.DOCUMENT, EPOCH_SECONDS) val timestampFormatType = RuntimeType.timestampFormat(runtimeConfig, timestampFormat) - rustTemplate( - "$writer.date_time(${value.asRef()}#{ConvertInto:W}, #{FormatType})?;", - "FormatType" to timestampFormatType, - "ConvertInto" to typeConversionGenerator.convertViaInto(target), - ) + rust("$writer.date_time(${value.asRef()}, #T)?;", timestampFormatType) } is CollectionShape -> jsonArrayWriter(context) { arrayName -> diff --git a/codegen-server/python/build.gradle.kts b/codegen-server/python/build.gradle.kts index bbc30ac6e..ac792966a 100644 --- a/codegen-server/python/build.gradle.kts +++ b/codegen-server/python/build.gradle.kts @@ -27,6 +27,9 @@ dependencies { implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + + // `smithy.framework#ValidationException` is defined here, which is used in `PythonServerTypesTest`. + testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" } diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt index be510db87..e17c058d7 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt @@ -144,4 +144,95 @@ internal class PythonServerTypesTest { cargoTest(testDir) } + + @Test + fun `timestamp type`() { + val model = """ + namespace test + + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service Service { + operations: [ + Echo, + ], + } + + @http(method: "POST", uri: "/echo") + operation Echo { + input: EchoInput, + output: EchoOutput, + errors: [ValidationException], + } + + structure EchoInput { + @required + value: Timestamp, + opt_value: Timestamp, + } + + structure EchoOutput { + @required + value: Timestamp, + opt_value: Timestamp, + } + """.asSmithyModel() + + val (pluginCtx, testDir) = generatePythonServerPluginContext(model) + executePythonServerCodegenVisitor(pluginCtx) + + val writer = RustWriter.forModule("service") + writer.tokioTest("timestamp_type") { + rust( + """ + use tower::Service as _; + use pyo3::{types::IntoPyDict, IntoPy, Python}; + use hyper::{Body, Request, body}; + use crate::{input, output, python_types}; + + pyo3::prepare_freethreaded_python(); + + let mut service = Service::builder_without_plugins() + .echo(|input: input::EchoInput| async { + Ok(Python::with_gil(|py| { + let globals = [ + ("EchoOutput", py.get_type::()), + ("DateTime", py.get_type::()), + ].into_py_dict(py); + let locals = [("input", input.into_py(py))].into_py_dict(py); + + py.run("assert input.value.secs() == 1676298520", Some(globals), Some(locals)).unwrap(); + py.run("output = EchoOutput(value=input.value, opt_value=DateTime.from_secs(1677771678))", Some(globals), Some(locals)).unwrap(); + + locals + .get_item("output") + .unwrap() + .extract::() + .unwrap() + })) + }) + .build() + .unwrap(); + + let req = Request::builder() + .method("POST") + .uri("/echo") + .body(Body::from("{\"value\":1676298520}")) + .unwrap(); + let res = service.call(req).await.unwrap(); + assert!(res.status().is_success()); + let body = body::to_bytes(res.into_body()).await.unwrap(); + let body = std::str::from_utf8(&body).unwrap(); + assert!(body.contains("\"value\":1676298520")); + assert!(body.contains("\"opt_value\":1677771678")); + """.trimIndent(), + ) + } + + testDir.resolve("src/service.rs").appendText(writer.toString()) + + cargoTest(testDir) + } } diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index 78bcac926..b0cba12b7 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -51,14 +51,22 @@ pretty_assertions = "1" futures-util = "0.3" tower-test = "0.4" tokio-test = "0.4" -pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] } +pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime", "unstable-streams"] } rcgen = "0.10.0" hyper-rustls = { version = "0.23.1", features = ["http2"] } +# PyO3 Asyncio tests cannot use Cargo's default testing harness because `asyncio` +# wants to control the main thread. So we need to use testing harness provided by `pyo3_asyncio` +# for the async Python tests. For more detail see: +# https://docs.rs/pyo3-asyncio/0.18.0/pyo3_asyncio/testing/index.html#pyo3-asyncio-testing-utilities [[test]] name = "middleware_tests" path = "src/middleware/pytests/harness.rs" harness = false +[[test]] +name = "python_tests" +path = "src/pytests/harness.rs" +harness = false [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-http-server-python/src/pytests/bytestream.rs b/rust-runtime/aws-smithy-http-server-python/src/pytests/bytestream.rs new file mode 100644 index 000000000..a5d7ab0c9 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/pytests/bytestream.rs @@ -0,0 +1,151 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::io; + +use futures::StreamExt; +use futures_util::stream; +use hyper::Body; +use pyo3::{prelude::*, py_run}; + +use aws_smithy_http::body::SdkBody; +use aws_smithy_http_server_python::types::ByteStream; + +#[pyo3_asyncio::tokio::test] +fn consuming_stream_on_python_synchronously() -> PyResult<()> { + let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]); + Python::with_gil(|py| { + let bytestream = bytestream.into_py(py); + py_run!( + py, + bytestream, + r#" +assert next(bytestream) == b"hello" +assert next(bytestream) == b" " +assert next(bytestream) == b"world" + +try: + next(bytestream) + assert False, "iteration should stop by now" +except StopIteration: + pass +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +fn consuming_stream_on_python_synchronously_with_loop() -> PyResult<()> { + let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]); + Python::with_gil(|py| { + let bytestream = bytestream.into_py(py); + py_run!( + py, + bytestream, + r#" +total = [] +for chunk in bytestream: + total.append(chunk) + +assert total == [b"hello", b" ", b"world"] +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +fn consuming_stream_on_python_asynchronously() -> PyResult<()> { + let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]); + Python::with_gil(|py| { + let bytestream = bytestream.into_py(py); + py_run!( + py, + bytestream, + r#" +import asyncio + +async def main(bytestream): + assert await bytestream.__anext__() == b"hello" + assert await bytestream.__anext__() == b" " + assert await bytestream.__anext__() == b"world" + + try: + await bytestream.__anext__() + assert False, "iteration should stop by now" + except StopAsyncIteration: + pass + +asyncio.run(main(bytestream)) +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +fn consuming_stream_on_python_asynchronously_with_loop() -> PyResult<()> { + let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]); + Python::with_gil(|py| { + let bytestream = bytestream.into_py(py); + py_run!( + py, + bytestream, + r#" +import asyncio + +async def main(bytestream): + total = [] + async for chunk in bytestream: + total.append(chunk) + assert total == [b"hello", b" ", b"world"] + +asyncio.run(main(bytestream)) +"# + ); + Ok(()) + }) +} + +#[pyo3_asyncio::tokio::test] +async fn streaming_back_to_rust_from_python() -> PyResult<()> { + let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]); + let py_stream = Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +async def handler(bytestream): + async for chunk in bytestream: + yield "🐍 " + chunk.decode("utf-8") + yield "Hello from Python!" +"#, + "", + "", + )?; + let handler = module.getattr("handler")?; + let output = handler.call1((bytestream,))?; + Ok::<_, PyErr>(pyo3_asyncio::tokio::into_stream_v2(output)) + })??; + + let mut py_stream = py_stream.map(|v| Python::with_gil(|py| v.extract::(py).unwrap())); + + assert_eq!(py_stream.next().await, Some("🐍 hello".to_string())); + assert_eq!(py_stream.next().await, Some("🐍 ".to_string())); + assert_eq!(py_stream.next().await, Some("🐍 world".to_string())); + assert_eq!( + py_stream.next().await, + Some("Hello from Python!".to_string()) + ); + assert_eq!(py_stream.next().await, None); + + Ok(()) +} + +fn streaming_bytestream_from_vec(chunks: Vec<&'static str>) -> ByteStream { + let stream = stream::iter(chunks.into_iter().map(|v| Ok::<_, io::Error>(v))); + let body = Body::wrap_stream(stream); + ByteStream::new(SdkBody::from(body)) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/pytests/harness.rs b/rust-runtime/aws-smithy-http-server-python/src/pytests/harness.rs new file mode 100644 index 000000000..069e4e987 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/pytests/harness.rs @@ -0,0 +1,11 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[pyo3_asyncio::tokio::main] +async fn main() -> pyo3::PyResult<()> { + pyo3_asyncio::testing::main().await +} + +mod bytestream; diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index 1ae44a043..32269bdac 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -4,6 +4,16 @@ */ //! Python wrapped types from aws-smithy-types and aws-smithy-http. +//! +//! ## `Deref` hacks for Json serializer +//! [aws_smithy_json::serialize::JsonValueWriter] expects references to the types +//! from [aws_smithy_types] (for example [aws_smithy_json::serialize::JsonValueWriter::document()] +//! expects `&aws_smithy_types::Document`). In order to make +//! [aws_smithy_json::serialize::JsonValueWriter] happy, we implement `Deref` traits for +//! Python types to their Rust counterparts (for example +//! `impl Deref for Document` and that allows `&Document` to +//! get coerced to `&aws_smithy_types::Document`). This is a hack, we should ideally handle this +//! in `JsonSerializerGenerator.kt` but it's not easy to do it with our current Kotlin structure. use std::{ collections::HashMap, @@ -16,10 +26,9 @@ use std::{ use bytes::Bytes; use pyo3::{ - exceptions::{PyRuntimeError, PyStopIteration, PyTypeError}, + exceptions::{PyRuntimeError, PyStopAsyncIteration, PyTypeError}, iter::IterNextOutput, prelude::*, - pyclass::IterANextOutput, }; use tokio::sync::Mutex; use tokio_stream::StreamExt; @@ -285,15 +294,11 @@ impl From for DateTime { } } -impl From for aws_smithy_types::DateTime { - fn from(other: DateTime) -> aws_smithy_types::DateTime { - other.0 - } -} +impl Deref for DateTime { + type Target = aws_smithy_types::DateTime; -impl<'date> From<&'date DateTime> for &'date aws_smithy_types::DateTime { - fn from(other: &'date DateTime) -> &'date aws_smithy_types::DateTime { - &other.0 + fn deref(&self) -> &Self::Target { + &self.0 } } @@ -353,12 +358,13 @@ impl futures::stream::Stream for ByteStream { /// Return a new data chunk from the stream. async fn yield_data_chunk( body: Arc>, -) -> PyResult { +) -> PyResult> { let mut stream = body.lock().await; - match stream.next().await { - Some(bytes) => bytes.map_err(|e| PyRuntimeError::new_err(e.to_string())), - None => Err(PyStopIteration::new_err("stream exhausted")), - } + stream + .next() + .await + .transpose() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) } impl ByteStream { @@ -444,14 +450,9 @@ impl ByteStream { let body = slf.0.clone(); let data_chunk = futures::executor::block_on(yield_data_chunk(body)); match data_chunk { - Ok(data_chunk) => Ok(IterNextOutput::Yield(data_chunk.into_py(slf.py()))), - Err(e) => { - if e.is_instance_of::(slf.py()) { - Ok(IterNextOutput::Return(slf.py().None())) - } else { - Err(e) - } - } + Ok(Some(data_chunk)) => Ok(IterNextOutput::Yield(data_chunk.into_py(slf.py()))), + Ok(None) => Ok(IterNextOutput::Return(slf.py().None())), + Err(e) => Err(e), } } @@ -463,29 +464,30 @@ impl ByteStream { } /// Return an awaitable resulting in a next value of the iterator or raise a StopAsyncIteration - /// exception when the iteration is over. PyO3 allows to raise the correct exception using the enum - /// [IterANextOutput](pyo3::pyclass::IterANextOutput). + /// exception when the iteration is over. /// /// To get the next value of the iterator, the `Arc` inner stream is cloned and the Rust call /// to `next()` is converted into an awaitable Python coroutine. /// /// More info: `` - pub fn __anext__(slf: PyRefMut) -> PyResult, PyObject>> { + /// + /// About the return type, we cannot use `IterANextOutput` because we don't know if we + /// have a next value or not until we call the `next` on the underlying stream which is + /// an async operation and it's awaited on the Python side. So we're returning + /// `StopAsyncIteration` inside the returned future lazily. + /// The reason for the extra `Option` wrapper is that PyO3 expects `__anext__` to return + /// either `Option` or `IterANextOutput` and fails to compile otherwise, so we're + /// using extra `Option` just to make PyO3 happy. + pub fn __anext__(slf: PyRefMut) -> PyResult> { let body = slf.0.clone(); - let data_chunk = pyo3_asyncio::tokio::local_future_into_py(slf.py(), async move { + let fut = pyo3_asyncio::tokio::future_into_py(slf.py(), async move { let data = yield_data_chunk(body).await?; - Ok(Python::with_gil(|py| data.into_py(py))) - }); - match data_chunk { - Ok(data_chunk) => Ok(IterANextOutput::Yield(data_chunk.into_py(slf.py()))), - Err(e) => { - if e.is_instance_of::(slf.py()) { - Ok(IterANextOutput::Return(slf.py().None())) - } else { - Err(e) - } + match data { + Some(data) => Ok(Python::with_gil(|py| data.into_py(py))), + None => Err(PyStopAsyncIteration::new_err("stream exhausted")), } - } + })?; + Ok(Some(fut.into())) } } @@ -550,10 +552,6 @@ impl FromPyObject<'_> for Document { } } -// TODO(PythonSerialization): Get rid of this hack. -// `JsonValueWriter::document` expects `&aws_smithy_types::Document` -// and this impl allows `&Document` to get coerced to `&aws_smithy_types::Document`. -// We should ideally handle this in `JsonSerializerGenerator.kt` but I'm not sure how hard it is. impl Deref for Document { type Target = aws_smithy_types::Document; -- GitLab