Unverified Commit e5c8cf30 authored by Matteo Bigoi's avatar Matteo Bigoi Committed by GitHub
Browse files

[Python] Allow to run pure Python request middlewares inside a Tower service (#1734)

## Motivation and Context
* Customers want to be able to implement simple middleware directly in Python. This PR aims to add the initial support for it.
* Improve the idiomatic experience of logging by exposing a handler compatible with Python's standard library `logging` module.

## Description
### Middleware
A middleware is defined as a sync or async Python function that can return multiple values, following these rules:

* Middleware not returning will let the execution continue without changing the original request.
* Middleware returning a modified Request will update the original request before continuing the execution.
* Middleware returning a Response will immediately terminate the request handling and return the response constructed from Python.
* Middleware raising MiddlewareException will immediately terminate the request handling and return a protocol specific error, with the option of setting the HTTP return code.
* Middleware raising any other exception will immediately terminate the request handling and return a protocol specific error, with HTTP status code 500.

Middlewares are registered into the Python application and executed in order of registration.

Example:

from sdk import App
from sdk.middleware import Request, MiddlewareException

app = App()

@app.request_middleware
def inject_header(request: Request):
    request.set_header("x-amzn-answer", "42")
    return request

@app.request_middleare
def check_header(request: Request):
    if request.get_header("x-amzn-answer") != "42":
        raise MiddlewareException("Wrong answer", 404)

@app.request_middleware
def dump_headers(request: Request):
    logging.debug(f"Request headers after middlewares: {request.headers()}")

**NOTE: this PR only adds support for request middlewares, which are executed before the operation handler. Response middlewares, executed after the operation are tracked here: https://github.com/awslabs/smithy-rs/issues/1754.**



### Logging
To improve the idiomatic experience, now logging need to be configured from the Python side by using the standard `logging` module. This allows customers to opt-out of our `tracing` based logging implementation and use their own and logging level is now driven by Python.

import logging
from sdk.logging import TracingHandler

logging.basicConfig(level=logging.INFO, handlers=[TracingHandler.handle()])

Signed-off-by: default avatarBigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: default avatarBurak <burakvar@amazon.co.uk>
parent 997beeb4
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -10,9 +10,14 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[rust-runtime]]
message = "Pokémon Service example code now runs clippy during build."
references = ["smithy-rs#1727"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" }
author = "GeneralSwiss"

[[smithy-rs]]
message = "Implement support for pure Python request middleware. Improve idiomatic logging support over tracing."
references = ["smithy-rs#1734"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" }
author = "crisidev"
+43 −22
Original line number Diff line number Diff line
@@ -30,7 +30,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
 * Example:
 *     from pool import DatabasePool
 *     from my_library import App, OperationInput, OperationOutput

 *     @dataclass
 *     class Context:
 *         db = DatabasePool()
@@ -69,6 +68,7 @@ class PythonApplicationGenerator(
    private val libName = "lib${coreCodegenContext.settings.moduleName.toSnakeCase()}"
    private val runtimeConfig = coreCodegenContext.runtimeConfig
    private val model = coreCodegenContext.model
    private val protocol = coreCodegenContext.protocol
    private val codegenScope =
        arrayOf(
            "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
@@ -101,6 +101,7 @@ class PythonApplicationGenerator(
            ##[derive(Debug, Default)]
            pub struct App {
                handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
                middlewares: #{SmithyPython}::PyMiddlewares,
                context: Option<#{pyo3}::PyObject>,
                workers: #{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>>,
            }
@@ -116,6 +117,7 @@ class PythonApplicationGenerator(
                fn clone(&self) -> Self {
                    Self {
                        handlers: self.handlers.clone(),
                        middlewares: self.middlewares.clone(),
                        context: self.context.clone(),
                        workers: #{parking_lot}::Mutex::new(vec![]),
                    }
@@ -151,7 +153,7 @@ class PythonApplicationGenerator(
                    val name = operationName.toSnakeCase()
                    rustTemplate(
                        """
                        let ${name}_locals = pyo3_asyncio::TaskLocals::new(event_loop);
                        let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
                        let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone();
                        let router = router.$name(move |input, state| {
                            #{pyo3_asyncio}::tokio::scope(${name}_locals, crate::operation_handler::$name(input, state, handler))
@@ -162,11 +164,20 @@ class PythonApplicationGenerator(
                }
                rustTemplate(
                    """
                    let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop);
                    use #{SmithyPython}::PyApp;
                    let service = #{tower}::ServiceBuilder::new().layer(
                        #{SmithyPython}::PyMiddlewareLayer::new(
                            self.middlewares.clone(),
                            self.protocol(),
                            middleware_locals
                        )?,
                    );
                    let router: #{SmithyServer}::routing::Router = router
                        .build()
                        .expect("Unable to build operation registry")
                        .into();
                    Ok(router)
                    Ok(router.layer(service))
                    """,
                    *codegenScope,
                )
@@ -175,20 +186,25 @@ class PythonApplicationGenerator(
    }

    private fun renderPyAppTrait(writer: RustWriter) {
        val protocol = protocol.toString().replace("#", "##")
        writer.rustTemplate(
            """
            impl #{SmithyPython}::PyApp for App {
                fn workers(&self) -> &#{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>> {
                    &self.workers
                }

                fn context(&self) -> &Option<#{pyo3}::PyObject> {
                    &self.context
                }

                fn handlers(&mut self) -> &mut #{HashMap}<String, #{SmithyPython}::PyHandler> {
                    &mut self.handlers
                }
                fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
                    &mut self.middlewares
                }
                fn protocol(&self) -> &'static str {
                    "$protocol"
                }
            }
            """,
            *codegenScope,
@@ -207,16 +223,20 @@ class PythonApplicationGenerator(
                """
                /// Create a new [App].
                ##[new]
                pub fn new(py: #{pyo3}::Python, log_level: Option<#{SmithyPython}::LogLevel>) -> #{pyo3}::PyResult<Self> {
                    let log_level = log_level.unwrap_or(#{SmithyPython}::LogLevel::Info);
                    #{SmithyPython}::logging::setup(py, log_level)?;
                    Ok(Self::default())
                pub fn new() -> Self {
                    Self::default()
                }
                /// Register a context object that will be shared between handlers.
                ##[pyo3(text_signature = "(${'$'}self, context)")]
                pub fn context(&mut self, context: #{pyo3}::PyObject) {
                   self.context = Some(context);
                }
                /// Register a request middleware function that will be run inside a Tower layer, without cloning the body.
                ##[pyo3(text_signature = "(${'$'}self, func)")]
                pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
                    use #{SmithyPython}::PyApp;
                    self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request)
                }
                /// Main entrypoint: start the server on multiple workers.
                ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
                pub fn run(
@@ -235,7 +255,7 @@ class PythonApplicationGenerator(
                pub fn start_worker(
                    &mut self,
                    py: pyo3::Python,
                    socket: &pyo3::PyCell<aws_smithy_http_server_python::PySocket>,
                    socket: &pyo3::PyCell<#{SmithyPython}::PySocket>,
                    worker_number: isize,
                ) -> pyo3::PyResult<()> {
                    use #{SmithyPython}::PyApp;
@@ -280,21 +300,17 @@ class PythonApplicationGenerator(
            """.trimIndent(),
        )
        writer.rust(
            if (operations.any { it.errors.isNotEmpty() }) {
            """
            /// from $libName import ${Inputs.namespace}
            /// from $libName import ${Outputs.namespace}
                /// from $libName import ${Errors.namespace}
                """.trimIndent()
            } else {
                """
                /// from $libName import ${Inputs.namespace}
                /// from $libName import ${Outputs.namespace}
                """.trimIndent()
            },
            """.trimIndent(),
        )
        if (operations.any { it.errors.isNotEmpty() }) {
            writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent())
        }
        writer.rust(
            """
            /// from $libName import middleware
            /// from $libName import App
            ///
            /// @dataclass
@@ -304,6 +320,11 @@ class PythonApplicationGenerator(
            /// app = App()
            /// app.context(Context())
            ///
            /// @app.request_middleware
            /// def request_middleware(request: middleware::Request):
            ///     if request.get_header("x-amzn-id") != "secret":
            ///         raise middleware.MiddlewareException("Unsupported `x-amz-id` header", 401)
            ///
            """.trimIndent(),
        )
        writer.operationImplementationStubs(operations)
+39 −0
Original line number Diff line number Diff line
@@ -47,6 +47,8 @@ class PythonServerModuleGenerator(
                renderPyCodegeneratedTypes()
                renderPyWrapperTypes()
                renderPySocketType()
                renderPyLogging()
                renderPyMiddlewareTypes()
                renderPyApplicationType()
            }
        }
@@ -125,6 +127,43 @@ class PythonServerModuleGenerator(
        )
    }

    // Render Python shared socket type.
    private fun RustWriter.renderPyLogging() {
        rustTemplate(
            """
            let logging = #{pyo3}::types::PyModule::new(py, "logging")?;
            logging.add_function(#{pyo3}::wrap_pyfunction!(#{SmithyPython}::py_tracing_event, m)?)?;
            logging.add_class::<#{SmithyPython}::PyTracingHandler>()?;
            #{pyo3}::py_run!(
                py,
                logging,
                "import sys; sys.modules['$libName.logging'] = logging"
            );
            m.add_submodule(logging)?;
            """,
            *codegenScope,
        )
    }

    private fun RustWriter.renderPyMiddlewareTypes() {
        rustTemplate(
            """
            let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?;
            middleware.add_class::<#{SmithyPython}::PyRequest>()?;
            middleware.add_class::<#{SmithyPython}::PyResponse>()?;
            middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?;
            middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?;
            pyo3::py_run!(
                py,
                middleware,
                "import sys; sys.modules['$libName.middleware'] = middleware"
            );
            m.add_submodule(middleware)?;
            """,
            *codegenScope,
        )
    }

    // Render Python application type.
    private fun RustWriter.renderPyApplicationType() {
        rustTemplate(
+8 −10
Original line number Diff line number Diff line
@@ -90,7 +90,6 @@ class PythonServerOperationHandlerGenerator(
            rustTemplate(
                """
                #{tracing}::debug!("Executing Python handler function `$name()`");
                #{tokio}::task::block_in_place(move || {
                #{pyo3}::Python::with_gil(|py| {
                    let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
                    let output = if handler.args == 1 {
@@ -100,7 +99,6 @@ class PythonServerOperationHandlerGenerator(
                    };
                    output.extract::<$output>()
                })
                })
                """,
                *codegenScope,
            )
+8 −2
Original line number Diff line number Diff line
@@ -13,14 +13,18 @@ Python server runtime for Smithy Rust Server Framework.
publish = true

[dependencies]
aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-http-server = { path = "../aws-smithy-http-server" }
aws-smithy-json = { path = "../aws-smithy-json" }
aws-smithy-types = { path = "../aws-smithy-types" }
aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-xml = { path = "../aws-smithy-xml" }
bytes = "1.2"
futures = "0.3"
http = "0.2"
hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] }
num_cpus = "1.13.1"
parking_lot = "0.12.1"
pin-project-lite = "0.2"
pyo3 = "0.16.5"
pyo3-asyncio = { version = "0.16.0", features = ["tokio-runtime"] }
signal-hook = { version = "0.3.14", features = ["extended-siginfo"] }
@@ -28,12 +32,14 @@ socket2 = { version = "0.4.4", features = ["all"] }
thiserror = "1.0.32"
tokio = { version = "1.20.1", features = ["full"] }
tokio-stream = "0.1"
tower = "0.4.13"
tower = { version = "0.4.13", features = ["util"] }
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["env-filter"] }
tracing-appender = { version = "0.2.2"}

[dev-dependencies]
pretty_assertions = "1"
futures-util = "0.3"

[package.metadata.docs.rs]
all-features = true
Loading