Unverified Commit 26cb37af authored by Burak's avatar Burak Committed by GitHub
Browse files

Python: Add tests and fix the issues with `Timestamp` and `ByteStream` (#2431)



* Add `timestamp` type test and fix the conversion error

* Add some tests for `ByteStream` and fix async issues

* Use `__anext__` method instead of `anext`

---------

Co-authored-by: default avatarMatteo Bigoi <1781140+crisidev@users.noreply.github.com>
parent 7ce80320
Loading
Loading
Loading
Loading
+1 −7
Original line number Diff line number Diff line
@@ -38,7 +38,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
@@ -179,7 +178,6 @@ class JsonSerializerGenerator(
    private val serializerUtil = SerializerUtil(model)
    private val operationSerModule = RustModule.private("operation_ser")
    private val jsonSerModule = RustModule.private("json_ser")
    private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig)

    /**
     * Reusable structure serializer implementation that can be used to generate serializing code for
@@ -407,11 +405,7 @@ class JsonSerializerGenerator(
                val timestampFormat =
                    httpBindingResolver.timestampFormat(context.shape, HttpLocation.DOCUMENT, EPOCH_SECONDS)
                val timestampFormatType = RuntimeType.timestampFormat(runtimeConfig, timestampFormat)
                rustTemplate(
                    "$writer.date_time(${value.asRef()}#{ConvertInto:W}, #{FormatType})?;",
                    "FormatType" to timestampFormatType,
                    "ConvertInto" to typeConversionGenerator.convertViaInto(target),
                )
                rust("$writer.date_time(${value.asRef()}, #T)?;", timestampFormatType)
            }

            is CollectionShape -> jsonArrayWriter(context) { arrayName ->
+3 −0
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@ dependencies {
    implementation(project(":codegen-server"))
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")

    // `smithy.framework#ValidationException` is defined here, which is used in `PythonServerTypesTest`.
    testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion")
}

tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" }
+91 −0
Original line number Diff line number Diff line
@@ -144,4 +144,95 @@ internal class PythonServerTypesTest {

        cargoTest(testDir)
    }

    @Test
    fun `timestamp type`() {
        val model = """
            namespace test

            use aws.protocols#restJson1
            use smithy.framework#ValidationException

            @restJson1
            service Service {
                operations: [
                    Echo,
                ],
            }

            @http(method: "POST", uri: "/echo")
            operation Echo {
                input: EchoInput,
                output: EchoOutput,
                errors: [ValidationException],
            }

            structure EchoInput {
                @required
                value: Timestamp,
                opt_value: Timestamp,
            }

            structure EchoOutput {
                @required
                value: Timestamp,
                opt_value: Timestamp,
            }
        """.asSmithyModel()

        val (pluginCtx, testDir) = generatePythonServerPluginContext(model)
        executePythonServerCodegenVisitor(pluginCtx)

        val writer = RustWriter.forModule("service")
        writer.tokioTest("timestamp_type") {
            rust(
                """
                use tower::Service as _;
                use pyo3::{types::IntoPyDict, IntoPy, Python};
                use hyper::{Body, Request, body};
                use crate::{input, output, python_types};
                
                pyo3::prepare_freethreaded_python();
                
                let mut service = Service::builder_without_plugins()
                    .echo(|input: input::EchoInput| async {
                        Ok(Python::with_gil(|py| {
                            let globals = [
                                ("EchoOutput", py.get_type::<output::EchoOutput>()),
                                ("DateTime", py.get_type::<python_types::DateTime>()),
                            ].into_py_dict(py);
                            let locals = [("input", input.into_py(py))].into_py_dict(py);
            
                            py.run("assert input.value.secs() == 1676298520", Some(globals), Some(locals)).unwrap();
                            py.run("output = EchoOutput(value=input.value, opt_value=DateTime.from_secs(1677771678))", Some(globals), Some(locals)).unwrap();
                            
                            locals
                                .get_item("output")
                                .unwrap()
                                .extract::<output::EchoOutput>()
                                .unwrap()
                        }))
                    })
                    .build()
                    .unwrap();

                let req = Request::builder()
                    .method("POST")
                    .uri("/echo")
                    .body(Body::from("{\"value\":1676298520}"))
                    .unwrap();
                let res = service.call(req).await.unwrap();
                assert!(res.status().is_success());
                let body = body::to_bytes(res.into_body()).await.unwrap();
                let body = std::str::from_utf8(&body).unwrap();
                assert!(body.contains("\"value\":1676298520"));
                assert!(body.contains("\"opt_value\":1677771678"));
                """.trimIndent(),
            )
        }

        testDir.resolve("src/service.rs").appendText(writer.toString())

        cargoTest(testDir)
    }
}
+9 −1
Original line number Diff line number Diff line
@@ -51,14 +51,22 @@ pretty_assertions = "1"
futures-util = "0.3"
tower-test = "0.4"
tokio-test = "0.4"
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime"] }
pyo3-asyncio = { version = "0.17.0", features = ["testing", "attributes", "tokio-runtime", "unstable-streams"] }
rcgen = "0.10.0"
hyper-rustls = { version = "0.23.1", features = ["http2"] }

# PyO3 Asyncio tests cannot use Cargo's default testing harness because `asyncio`
# wants to control the main thread. So we need to use testing harness provided by `pyo3_asyncio`
# for the async Python tests. For more detail see: 
# https://docs.rs/pyo3-asyncio/0.18.0/pyo3_asyncio/testing/index.html#pyo3-asyncio-testing-utilities
[[test]]
name = "middleware_tests"
path = "src/middleware/pytests/harness.rs"
harness = false
[[test]]
name = "python_tests"
path = "src/pytests/harness.rs"
harness = false

[package.metadata.docs.rs]
all-features = true
+151 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::io;

use futures::StreamExt;
use futures_util::stream;
use hyper::Body;
use pyo3::{prelude::*, py_run};

use aws_smithy_http::body::SdkBody;
use aws_smithy_http_server_python::types::ByteStream;

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_synchronously() -> PyResult<()> {
    let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
    Python::with_gil(|py| {
        let bytestream = bytestream.into_py(py);
        py_run!(
            py,
            bytestream,
            r#"
assert next(bytestream) == b"hello"
assert next(bytestream) == b" "
assert next(bytestream) == b"world"

try:
    next(bytestream)
    assert False, "iteration should stop by now"
except StopIteration:
    pass
"#
        );
        Ok(())
    })
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_synchronously_with_loop() -> PyResult<()> {
    let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
    Python::with_gil(|py| {
        let bytestream = bytestream.into_py(py);
        py_run!(
            py,
            bytestream,
            r#"
total = []
for chunk in bytestream:
    total.append(chunk)

assert total == [b"hello", b" ", b"world"]
"#
        );
        Ok(())
    })
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_asynchronously() -> PyResult<()> {
    let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
    Python::with_gil(|py| {
        let bytestream = bytestream.into_py(py);
        py_run!(
            py,
            bytestream,
            r#"
import asyncio

async def main(bytestream):
    assert await bytestream.__anext__() == b"hello"
    assert await bytestream.__anext__() == b" "
    assert await bytestream.__anext__() == b"world"

    try:
        await bytestream.__anext__()
        assert False, "iteration should stop by now"
    except StopAsyncIteration:
        pass

asyncio.run(main(bytestream))
"#
        );
        Ok(())
    })
}

#[pyo3_asyncio::tokio::test]
fn consuming_stream_on_python_asynchronously_with_loop() -> PyResult<()> {
    let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
    Python::with_gil(|py| {
        let bytestream = bytestream.into_py(py);
        py_run!(
            py,
            bytestream,
            r#"
import asyncio

async def main(bytestream):
    total = []
    async for chunk in bytestream:
        total.append(chunk)
    assert total == [b"hello", b" ", b"world"]

asyncio.run(main(bytestream))
"#
        );
        Ok(())
    })
}

#[pyo3_asyncio::tokio::test]
async fn streaming_back_to_rust_from_python() -> PyResult<()> {
    let bytestream = streaming_bytestream_from_vec(vec!["hello", " ", "world"]);
    let py_stream = Python::with_gil(|py| {
        let module = PyModule::from_code(
            py,
            r#"
async def handler(bytestream):
    async for chunk in bytestream:
        yield "🐍 " + chunk.decode("utf-8")
    yield "Hello from Python!"
"#,
            "",
            "",
        )?;
        let handler = module.getattr("handler")?;
        let output = handler.call1((bytestream,))?;
        Ok::<_, PyErr>(pyo3_asyncio::tokio::into_stream_v2(output))
    })??;

    let mut py_stream = py_stream.map(|v| Python::with_gil(|py| v.extract::<String>(py).unwrap()));

    assert_eq!(py_stream.next().await, Some("🐍 hello".to_string()));
    assert_eq!(py_stream.next().await, Some("🐍  ".to_string()));
    assert_eq!(py_stream.next().await, Some("🐍 world".to_string()));
    assert_eq!(
        py_stream.next().await,
        Some("Hello from Python!".to_string())
    );
    assert_eq!(py_stream.next().await, None);

    Ok(())
}

fn streaming_bytestream_from_vec(chunks: Vec<&'static str>) -> ByteStream {
    let stream = stream::iter(chunks.into_iter().map(|v| Ok::<_, io::Error>(v)));
    let body = Body::wrap_stream(stream);
    ByteStream::new(SdkBody::from(body))
}
Loading