Unverified Commit 3229089f authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Allow python middleware to set the URI of a request (#3002)



## Motivation and Context
Python middleware were unable to change the URI which makes it
impossible to change the route of an incoming request in Python.

## Description
URI has a #[setter] that Python side can use to change it.

## Testing
- Test has been added that ensures URI can be changed from python.
- If the URI is modified by a middleware in a local smithy-rs Python
server, a different route will be invoked.

## Checklist
- I've updated the changelog.next.toml

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 3c11815e
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -195,3 +195,9 @@ message = "The `futures_core::stream::Stream` trait has been removed from public
references = ["smithy-rs#2978"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"

[[smithy-rs]]
message = "Python middleware can set URI. This can be used to route a request to a different handler."
references = ["smithy-rs#3005"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server" }
author = "drganjoo"
 No newline at end of file
+18 −2
Original line number Diff line number Diff line
@@ -141,7 +141,14 @@ async def middleware(request, next):
    body = bytes(await request.body).decode()
    body_reversed = body[::-1]
    request.body = body_reversed.encode()
    # Add a new header
    request.headers["X-From-Middleware"] = "yes"
    # Change an existing header.
    request.headers["X-Existing"] = "changed"
    # Delete an existing header.
    del request.headers["X-To-Delete"]
    # Change the URI.
    request.uri = "/changed_uri"
    return await next(request)
"#,
    );
@@ -149,7 +156,10 @@ async def middleware(request, next):

    let th = tokio::spawn(async move {
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(&"yes", req.headers().get("X-From-Middleware").unwrap());
        assert_eq!("yes", req.headers().get("X-From-Middleware").unwrap());
        assert_eq!("changed", req.headers().get("X-Existing").unwrap());
        assert!(req.headers().get("X-To-Delete").is_none());
        assert_eq!("/changed_uri", req.uri());
        let req_body = hyper::body::to_bytes(req.into_body()).await.unwrap();
        assert_eq!(req_body, "hello server".chars().rev().collect::<String>());
        send_response.send_response(
@@ -159,7 +169,13 @@ async def middleware(request, next):
        );
    });

    let request = simple_request("hello server");
    let mut request = simple_request("hello server");
    assert_ne!(request.uri(), "/changed_uri");
    // Add a header that the middleware should modify.
    let headers_mut = request.headers_mut();
    headers_mut.insert("X-Existing", http::HeaderValue::from_static("yes"));
    // Add a header that the middleware should remove.
    headers_mut.insert("X-To-Delete", http::HeaderValue::from_static("delete-this"));
    let response = service.call(request);
    assert_body(response.await?, "hello client").await;

+49 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
use aws_smithy_http_server_python::PyRequest;
use http::{Request, Version};
use hyper::Body;
use pyo3::{prelude::*, py_run};
use pyo3::{exceptions::PyValueError, prelude::*, py_run};

#[pyo3_asyncio::tokio::test]
async fn accessing_request_properties() -> PyResult<()> {
@@ -71,3 +71,51 @@ async def handler(req):

    Ok(())
}

#[pyo3_asyncio::tokio::test]
async fn accessing_and_changing_request_uri() -> PyResult<()> {
    let request = Request::builder()
        .uri("/op1")
        .body(Body::from("hello world"))
        .expect("could not build request");
    let py_request = PyRequest::new(request);

    // Call an async Python method to change the URI and return it.
    let modified_req = Python::with_gil(|py| {
        let module = PyModule::from_code(
            py,
            r#"
async def handler(req):
    assert req.uri == "/op1"
    # add a trailing slash to the uri
    req.uri = "/op1/"
    assert req.uri == "/op1/"
    return req
"#,
            "",
            "",
        )?;

        let req_ref = PyCell::new(py, py_request)?;
        let handler = module.getattr("handler")?;
        let output = handler.call1((req_ref,))?;

        Ok::<_, PyErr>(pyo3_asyncio::tokio::into_future(output))
    })??
    .await?;

    // Confirm that the URI has been changed when the modified PyRequest instance
    // from Python is converted into a http::Request<> instance.
    Python::with_gil(|py| {
        let request_cell: &PyCell<PyRequest> = modified_req.downcast(py)?;
        let mut request = request_cell.borrow_mut();
        let http_request = request
            .take_inner()
            .ok_or_else(|| PyValueError::new_err("inner http request has already been consumed"))?;
        assert_eq!(http_request.uri(), "/op1/");

        Ok::<_, PyErr>(())
    })?;

    Ok(())
}
+23 −1
Original line number Diff line number Diff line
@@ -10,7 +10,10 @@ use std::sync::Arc;

use aws_smithy_http_server::body::Body;
use http::{request::Parts, Request};
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use pyo3::{
    exceptions::{PyRuntimeError, PyValueError},
    prelude::*,
};
use tokio::sync::Mutex;

use super::{PyHeaderMap, PyMiddlewareError};
@@ -76,6 +79,25 @@ impl PyRequest {
            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
    }

    /// Sets the URI of this request.
    ///
    /// :type str:
    #[setter]
    fn set_uri(&mut self, uri_str: String) -> PyResult<()> {
        self.parts.as_mut().map_or_else(
            || Err(PyMiddlewareError::RequestGone.into()),
            |parts| {
                parts.uri = uri_str.parse().map_err(|e: http::uri::InvalidUri| {
                    PyValueError::new_err(format!(
                        "URI `{}` cannot be parsed. Error: {}",
                        uri_str, e
                    ))
                })?;
                Ok(())
            },
        )
    }

    /// Return the HTTP version of this request.
    ///
    /// :type str: