Unverified Commit 5c7b50dd authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Fix Rust 1.82 compilation errors by upgrading PyO3 from 0.18 to 0.20 (#4146)

## PyO3 Upgrade to Fix Compatibility with Rust 1.82

This PR addresses a compatibility issue between `PyO3 0.18` and `Rust
1.82`, which causes compilation errors due to unexpected `cfg` condition
names.

When compiled with Rust 1.82, the current [PyO3
0.18](https://github.com/PyO3/pyo3/blob/v0.18.3/src/types/mod.rs#L193

)
dependency produces errors like:
```
error: unexpected `cfg` condition name: `addr_of`
```

This error occurs because of changes in how Rust handles conditional
configuration checks in newer compiler versions.

## Changes in this PR:

1. Upgrades PyO3 from version 0.18 to 0.20 (the latest version
compatible with pyo3-asyncio)
2. Updates generated functions to follow the required API pattern where
required fields precede optional ones
3. API changes in pyo3, whereby in the new version dictionary.get_key
returns a Result<Option<PyAny>> instead of Result<PyAny> in the older
version.

This upgrade ensures compatibility with Rust 1.82 while maintaining all
existing functionality.

## Testing

1. Current protocol tests pass
2. An additional test has been included for required versus optional
parameters.

---------

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 7d64b2f9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -15,9 +15,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
 * For a dependency that is used in the client, or in both the client and the server, use [CargoDependency] directly.
 */
object PythonServerCargoDependency {
    val PyO3: CargoDependency = CargoDependency("pyo3", CratesIo("0.18"))
    val PyO3: CargoDependency = CargoDependency("pyo3", CratesIo("0.20"))
    val PyO3Asyncio: CargoDependency =
        CargoDependency("pyo3-asyncio", CratesIo("0.18"), features = setOf("attributes", "tokio-runtime"))
        CargoDependency("pyo3-asyncio", CratesIo("0.20"), features = setOf("attributes", "tokio-runtime"))
    val Tokio: CargoDependency = CargoDependency("tokio", CratesIo("1.20.1"), features = setOf("full"))
    val TokioStream: CargoDependency = CargoDependency("tokio-stream", CratesIo("0.1.12"))
    val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1"))
+1 −1
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@ class PythonServerModuleGenerator(
            rustBlockTemplate(
                """
                ##[#{pyo3}::pymodule]
                ##[#{pyo3}(name = "$libName")]
                ##[pyo3(name = "$libName")]
                pub fn python_library(py: #{pyo3}::Python<'_>, m: &#{pyo3}::types::PyModule) -> #{pyo3}::PyResult<()>
                """,
                *codegenScope,
+12 −3
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustInlineTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
@@ -125,9 +126,17 @@ class PythonServerStructureGenerator(
        )
    }

    // Python function parameters require that all required parameters appear before optional ones.
    // This function sorts the member fields to ensure required fields precede optional fields.
    private fun sortedMembers() =
        members.sortedBy { member ->
            val memberSymbol = symbolProvider.toSymbol(member)
            memberSymbol.isOptional()
        }

    private fun renderStructSignatureMembers(): Writable =
        writable {
            forEachMember(members) { _, memberName, memberSymbol ->
            forEachMember(sortedMembers()) { _, memberName, memberSymbol ->
                val memberType = memberSymbol.rustType()
                rust("$memberName: ${memberType.render()},")
            }
@@ -135,14 +144,14 @@ class PythonServerStructureGenerator(

    private fun renderStructBodyMembers(): Writable =
        writable {
            forEachMember(members) { _, memberName, _ ->
            forEachMember(sortedMembers()) { _, memberName, _ ->
                rust("$memberName,")
            }
        }

    private fun renderConstructorSignature(): Writable =
        writable {
            forEachMember(members) { member, memberName, memberSymbol ->
            forEachMember(sortedMembers()) { member, memberName, memberSymbol ->
                val memberType = memberPythonType(member, memberSymbol)
                rust("/// :param $memberName ${memberType.renderAsDocstring()}:")
            }
+226 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.server.python.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.python.smithy.testutil.cargoTest
import software.amazon.smithy.rust.codegen.server.python.smithy.testutil.executePythonServerCodegenVisitor
import software.amazon.smithy.rust.codegen.server.python.smithy.testutil.generatePythonServerPluginContext
import kotlin.io.path.appendText

internal class PythonServerRequiredPrecedeOptionalTest {
    @Test
    fun `mandatory fields are reordered to be before optional`() {
        val model =
            """
            namespace test

            use aws.protocols#restJson1
            use smithy.framework#ValidationException

            @restJson1
            service SampleService {
                operations: [
                    OpWithIncorrectOrder, OpWithCorrectOrder, OpWithDefaults
                ],
            }

            @http(method: "POST", uri: "/opIncorrect")
            operation OpWithIncorrectOrder {
                input:= {
                    a: String
                    @required
                    b: String
                    c: String
                    @required
                    d: String
                }
                output:= {
                    a: String
                    @required
                    b: String
                    c: String
                    @required
                    d: String
                }
                errors: [ValidationException]
            }

            @http(method: "POST", uri: "/opCorrect")
            operation OpWithCorrectOrder {
                input:= {
                    @required
                    b: String
                    @required
                    d: String
                    a: String
                    c: String
                }
                output:= {
                    @required
                    b: String
                    @required
                    d: String
                    a: String
                    c: String
                }
                errors: [ValidationException]
            }

            @http(method: "POST", uri: "/opWithDefaults")
            operation OpWithDefaults {
                input:= {
                    a: String,
                    b: String = "hi"
                }
                output:= {
                    a: String,
                    b: String = "hi"
                }
            }
            """.asSmithyModel(smithyVersion = "2")

        val (pluginCtx, testDir) = generatePythonServerPluginContext(model)
        executePythonServerCodegenVisitor(pluginCtx)

        val writer = RustWriter.forModule("service")
        writer.unitTest("test_required_fields") {
            fun createInstanceWithRequiredFieldsOnly(
                module: String,
                typeName: String,
            ) = writable {
                rustTemplate(
                    """
                    py.run(
                        "data = $typeName(\"b\", \"d\")",
                        Some(globals),
                        Some(locals),
                    ).unwrap();

                    // Python should have been able to construct input.
                    let data = locals
                        .get_item("data")
                        .expect("Python exception occurred during dictionary lookup")
                        .unwrap()
                        .extract::<$module::$typeName>()
                        .unwrap();
                    assert_eq!(data.b, "b");
                    assert_eq!(data.d, "d");
                    """,
                )
            }

            fun createInstance(
                module: String,
                typeName: String,
            ) = writable {
                rustTemplate(
                    """
                    py.run(
                        "data = $typeName(\"b\", \"d\", a = \"a\", c = \"c\")",
                        Some(globals),
                        Some(locals),
                    ).unwrap();

                    // Python should have been able to construct input.
                    let data = locals
                        .get_item("data")
                        .expect("Python exception occurred during dictionary lookup")
                        .unwrap()
                        .extract::<$module::$typeName>()
                        .unwrap();
                    assert_eq!(data.b, "b");
                    assert_eq!(data.d, "d");
                    assert_eq!(data.a, Some("a".to_string()));
                    assert_eq!(data.c, Some("c".to_string()));
                    """,
                )
            }

            fun createDefaultInstance(
                module: String,
                typeName: String,
            ) = writable {
                rustTemplate(
                    """
                    // Default values are not exported from Rust. However, they
                    // are marked as non-optional.
                    py.run(
                        "data = $typeName(\"b\", \"a\")",
                        Some(globals),
                        Some(locals),
                    ).unwrap();

                    // Python should have been able to construct input.
                    let data = locals
                        .get_item("data")
                        .expect("Python exception occurred during dictionary lookup")
                        .unwrap()
                        .extract::<$module::$typeName>()
                        .unwrap();
                    assert_eq!(data.a, Some("a".to_string()));
                    assert_eq!(data.b, "b");
                    """,
                )
            }

            rustTemplate(
                """
                use crate::{input, output};
                use #{pyo3}::{types::IntoPyDict, Python};

                pyo3::prepare_freethreaded_python();
                Python::with_gil(|py| {
                    let globals = [
                        ("OpWithIncorrectOrderInput", py.get_type::<input::OpWithIncorrectOrderInput>()),
                        ("OpWithCorrectOrderInput", py.get_type::<input::OpWithCorrectOrderInput>()),
                        ("OpWithDefaultsInput", py.get_type::<input::OpWithDefaultsInput>()),
                        ("OpWithIncorrectOrderOutput", py.get_type::<output::OpWithIncorrectOrderOutput>()),
                        ("OpWithCorrectOrderOutput", py.get_type::<output::OpWithCorrectOrderOutput>()),
                        ("OpWithDefaultsOutput", py.get_type::<output::OpWithDefaultsOutput>())
                        ]
                        .into_py_dict(py);

                    let locals = [("OpWithIncorrectOrderInput", py.get_type::<input::OpWithIncorrectOrderInput>())].into_py_dict(py);

                    #{IncorrectOrderInputRequiredOnly}
                    #{CorrectOrderInputRequiredOnly}
                    #{IncorrectOrderOutputRequiredOnly}
                    #{CorrectOrderOutputRequiredOnly}
                    #{IncorrectOrderInput}
                    #{CorrectOrderInput}
                    #{IncorrectOrderOutput}
                    #{CorrectOrderOutput}
                    #{DefaultsInput}
                    #{DefaultsOutput}
                });
                """,
                "pyo3" to PythonServerCargoDependency.PyO3.toDevDependency().toType(),
                "IncorrectOrderInputRequiredOnly" to createInstanceWithRequiredFieldsOnly("input", "OpWithIncorrectOrderInput"),
                "CorrectOrderInputRequiredOnly" to createInstanceWithRequiredFieldsOnly("input", "OpWithCorrectOrderInput"),
                "IncorrectOrderOutputRequiredOnly" to createInstanceWithRequiredFieldsOnly("output", "OpWithIncorrectOrderOutput"),
                "CorrectOrderOutputRequiredOnly" to createInstanceWithRequiredFieldsOnly("output", "OpWithCorrectOrderOutput"),
                "IncorrectOrderInput" to createInstance("input", "OpWithIncorrectOrderInput"),
                "CorrectOrderInput" to createInstance("input", "OpWithCorrectOrderInput"),
                "IncorrectOrderOutput" to createInstance("output", "OpWithIncorrectOrderOutput"),
                "CorrectOrderOutput" to createInstance("output", "OpWithCorrectOrderOutput"),
                "DefaultsInput" to createDefaultInstance("input", "OpWithDefaultsInput"),
                "DefaultsOutput" to createDefaultInstance("output", "OpWithDefaultsOutput"),
            )
        }

        testDir.resolve("src/service.rs").appendText(writer.toString())

        cargoTest(testDir)
    }
}
+2 −0
Original line number Diff line number Diff line
@@ -119,6 +119,7 @@ internal class PythonServerTypesTest {

                                locals
                                    .get_item("output")
                                    .expect("Python exception occurred during dictionary lookup")
                                    .unwrap()
                                    .extract::<output::EchoOutput>()
                                    .unwrap()
@@ -212,6 +213,7 @@ internal class PythonServerTypesTest {

                            locals
                                .get_item("output")
                                .expect("Python exception occurred during dictionary lookup")
                                .unwrap()
                                .extract::<output::EchoOutput>()
                                .unwrap()
Loading