Unverified Commit 9c585a22 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

[Python] Improve Python stubs generation (#2606)



## Motivation and Context
This PR improves the Python stubs generation.

## Description
The main change is about avoiding to setup a placeholder for the Python
module and use the real module name, which allows to generate correct
docstrings during codegeneration.

We also change the stubs layout on disk, with the main stub entrypoint
called `__init__.pyi` instead of `$module_name.pyi`.

The README from the Rust runtime crate has been moved completely to the
example folder and I run autoformatting and style checks on the Python
example code.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: default avatarBurak <unexge@gmail.com>
parent c805f89b
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -152,7 +152,7 @@ class PythonServerCodegenVisitor(
        rustCrate.useShapeWriter(shape) {
            // Use Python specific structure generator that adds the #[pyclass] attribute
            // and #[pymethods] implementation.
            PythonServerStructureGenerator(model, codegenContext.symbolProvider, this, shape).render()
            PythonServerStructureGenerator(model, codegenContext, this, shape).render()

            shape.getTrait<ErrorTrait>()?.also { errorTrait ->
                ErrorImplGenerator(
@@ -190,7 +190,7 @@ class PythonServerCodegenVisitor(
    override fun unionShape(shape: UnionShape) {
        logger.info("[python-server-codegen] Generating an union shape $shape")
        rustCrate.useShapeWriter(shape) {
            PythonServerUnionGenerator(model, codegenContext.symbolProvider, this, shape, renderUnknownVariant = false).render()
            PythonServerUnionGenerator(model, codegenContext, this, shape, renderUnknownVariant = false).render()
        }

        if (shape.isReachableFromOperationInput() && shape.canReachConstrainedShape(
+21 −21
Original line number Diff line number Diff line
@@ -100,17 +100,14 @@ sealed class PythonType {
        override val namespace = type.namespace
    }

    data class Opaque(override val name: String, val rustNamespace: String? = null) : PythonType() {
        // Since Python doesn't have a something like Rust's `crate::` we are using a custom placeholder here
        // and in our stub generation script we will replace placeholder with the real root module name.
        private val pythonRootModulePlaceholder = "__root_module_name__"
    data class Opaque(override val name: String, val pythonRootModuleName: String, val rustNamespace: String? = null) : PythonType() {

        override val namespace: String? = rustNamespace?.split("::")?.joinToString(".") {
            when (it) {
                "crate" -> pythonRootModulePlaceholder
                "crate" -> pythonRootModuleName
                // In Python, we expose submodules from `aws_smithy_http_server_python`
                // like `types`, `middleware`, `tls` etc. from `__root_module__name`
                "aws_smithy_http_server_python" -> pythonRootModulePlaceholder
                // like `types`, `middleware`, `tls` etc. from Python root module
                "aws_smithy_http_server_python" -> pythonRootModuleName
                else -> it
            }
        }
@@ -120,26 +117,29 @@ sealed class PythonType {
/**
 * Return corresponding [PythonType] for a [RustType].
 */
fun RustType.pythonType(): PythonType =
fun RustType.pythonType(pythonRootModuleName: String): PythonType =
    when (this) {
        is RustType.Unit -> PythonType.None
        is RustType.Bool -> PythonType.Bool
        is RustType.Float -> PythonType.Float
        is RustType.Integer -> PythonType.Int
        is RustType.String -> PythonType.Str
        is RustType.Vec -> PythonType.List(this.member.pythonType())
        is RustType.Slice -> PythonType.List(this.member.pythonType())
        is RustType.HashMap -> PythonType.Dict(this.key.pythonType(), this.member.pythonType())
        is RustType.HashSet -> PythonType.Set(this.member.pythonType())
        is RustType.Reference -> this.member.pythonType()
        is RustType.Option -> PythonType.Optional(this.member.pythonType())
        is RustType.Box -> this.member.pythonType()
        is RustType.Dyn -> this.member.pythonType()
        is RustType.Application -> PythonType.Application(this.type.pythonType(), this.args.map { it.pythonType() })
        is RustType.Opaque -> PythonType.Opaque(this.name, this.namespace)
        // TODO(Constraints): How to handle this?
        // Revisit as part of https://github.com/awslabs/smithy-rs/issues/2114
        is RustType.MaybeConstrained -> this.member.pythonType()
        is RustType.Vec -> PythonType.List(this.member.pythonType(pythonRootModuleName))
        is RustType.Slice -> PythonType.List(this.member.pythonType(pythonRootModuleName))
        is RustType.HashMap -> PythonType.Dict(this.key.pythonType(pythonRootModuleName), this.member.pythonType(pythonRootModuleName))
        is RustType.HashSet -> PythonType.Set(this.member.pythonType(pythonRootModuleName))
        is RustType.Reference -> this.member.pythonType(pythonRootModuleName)
        is RustType.Option -> PythonType.Optional(this.member.pythonType(pythonRootModuleName))
        is RustType.Box -> this.member.pythonType(pythonRootModuleName)
        is RustType.Dyn -> this.member.pythonType(pythonRootModuleName)
        is RustType.Application -> PythonType.Application(
            this.type.pythonType(pythonRootModuleName),
            this.args.map {
                it.pythonType(pythonRootModuleName)
            },
        )
        is RustType.Opaque -> PythonType.Opaque(this.name, pythonRootModuleName, rustNamespace = this.namespace)
        is RustType.MaybeConstrained -> this.member.pythonType(pythonRootModuleName)
    }

/**
+6 −6
Original line number Diff line number Diff line
@@ -250,11 +250,11 @@ class PythonApplicationGenerator(
            """,
            *codegenScope,
        ) {
            val middlewareRequest = PythonType.Opaque("Request", "crate::middleware")
            val middlewareResponse = PythonType.Opaque("Response", "crate::middleware")
            val middlewareRequest = PythonType.Opaque("Request", libName, rustNamespace = "crate::middleware")
            val middlewareResponse = PythonType.Opaque("Response", libName, rustNamespace = "crate::middleware")
            val middlewareNext = PythonType.Callable(listOf(middlewareRequest), PythonType.Awaitable(middlewareResponse))
            val middlewareFunc = PythonType.Callable(listOf(middlewareRequest, middlewareNext), PythonType.Awaitable(middlewareResponse))
            val tlsConfig = PythonType.Opaque("TlsConfig", "crate::tls")
            val tlsConfig = PythonType.Opaque("TlsConfig", libName, rustNamespace = "crate::tls")

            rustTemplate(
                """
@@ -344,9 +344,9 @@ class PythonApplicationGenerator(
                val operationName = symbolProvider.toSymbol(operation).name
                val fnName = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operation).name.toSnakeCase())

                val input = PythonType.Opaque("${operationName}Input", "crate::input")
                val output = PythonType.Opaque("${operationName}Output", "crate::output")
                val context = PythonType.Opaque("Ctx")
                val input = PythonType.Opaque("${operationName}Input", libName, rustNamespace = "crate::input")
                val output = PythonType.Opaque("${operationName}Output", libName, rustNamespace = "crate::output")
                val context = PythonType.Opaque("Ctx", libName)
                val returnType = PythonType.Union(listOf(output, PythonType.Awaitable(output)))
                val handler = PythonType.Union(
                    listOf(
+8 −5
Original line number Diff line number Diff line
@@ -18,16 +18,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustInlineTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonEventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonType
import software.amazon.smithy.rust.codegen.server.python.smithy.pythonType
import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstring
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext

/**
 * To share structures defined in Rust with Python, `pyo3` provides the `PyClass` trait.
@@ -36,11 +37,13 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstrin
 */
class PythonServerStructureGenerator(
    model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val codegenContext: ServerCodegenContext,
    private val writer: RustWriter,
    private val shape: StructureShape,
) : StructureGenerator(model, symbolProvider, writer, shape, emptyList()) {
) : StructureGenerator(model, codegenContext.symbolProvider, writer, shape, emptyList()) {

    private val symbolProvider = codegenContext.symbolProvider
    private val libName = codegenContext.settings.moduleName.toSnakeCase()
    private val pyO3 = PythonServerCargoDependency.PyO3.toType()

    override fun renderStructure() {
@@ -157,9 +160,9 @@ class PythonServerStructureGenerator(
    private fun memberPythonType(shape: MemberShape, symbol: Symbol): PythonType =
        if (shape.isEventStream(model)) {
            val eventStreamSymbol = PythonEventStreamSymbolProvider.parseSymbol(symbol)
            val innerT = eventStreamSymbol.innerT.pythonType()
            val innerT = eventStreamSymbol.innerT.pythonType(libName)
            PythonType.AsyncIterator(innerT)
        } else {
            symbol.rustType().pythonType()
            symbol.rustType().pythonType(libName)
        }
}
+9 −5
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency
import software.amazon.smithy.rust.codegen.server.python.smithy.pythonType
import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstring
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext

/*
 * Generate unions that are compatible with Python by wrapping the Rust implementation into
@@ -34,11 +35,13 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstrin
 */
class PythonServerUnionGenerator(
    model: Model,
    private val symbolProvider: SymbolProvider,
    private val codegenContext: ServerCodegenContext,
    private val writer: RustWriter,
    shape: UnionShape,
    private val renderUnknownVariant: Boolean = true,
) : UnionGenerator(model, symbolProvider, writer, shape, renderUnknownVariant) {
) : UnionGenerator(model, codegenContext.symbolProvider, writer, shape, renderUnknownVariant) {
    private val symbolProvider = codegenContext.symbolProvider
    private val libName = codegenContext.settings.moduleName.toSnakeCase()
    private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
    private val unionSymbol = symbolProvider.toSymbol(shape)

@@ -125,7 +128,7 @@ class PythonServerUnionGenerator(
            }
        } else {
            val memberSymbol = symbolProvider.toSymbol(member)
            val pythonType = memberSymbol.rustType().pythonType()
            val pythonType = memberSymbol.rustType().pythonType(libName)
            val targetType = memberSymbol.rustType()
            Attribute("staticmethod").render(writer)
            writer.rust(
@@ -166,7 +169,7 @@ class PythonServerUnionGenerator(
            }
        } else {
            val memberSymbol = symbolProvider.toSymbol(member)
            val pythonType = memberSymbol.rustType().pythonType()
            val pythonType = memberSymbol.rustType().pythonType(libName)
            val targetSymbol = symbolProvider.toSymbol(model.expectShape(member.target))
            val rustType = memberSymbol.rustType()
            writer.rust(
@@ -181,12 +184,13 @@ class PythonServerUnionGenerator(
                } else {
                    "variant.clone()"
                }
                val errorVariant = memberSymbol.rustType().pythonType(libName).renderAsDocstring()
                rustTemplate(
                    """
                    match self.0.as_$funcNamePart() {
                        Ok(variant) => Ok($variantType),
                        Err(_) => Err(#{pyo3}::exceptions::PyValueError::new_err(
                            r"${unionSymbol.name} variant is not of type ${memberSymbol.rustType().pythonType().renderAsDocstring()}"
                            r"${unionSymbol.name} variant is not of type $errorVariant"
                        )),
                    }
                    """,
Loading