Unverified Commit df077c97 authored by Weihang Lo's avatar Weihang Lo Committed by GitHub
Browse files

Use `tokio::sync::Mutex` to make lock get across yield point (#1677)

parent 778539f4
Loading
Loading
Loading
Loading
+9 −10
Original line number Diff line number Diff line
@@ -6,19 +6,20 @@
//! Python wrapped types from aws-smithy-types and aws-smithy-http.

use std::{
    future::Future,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use bytes::Bytes;
use parking_lot::Mutex;
use pyo3::{
    exceptions::{PyRuntimeError, PyStopIteration},
    iter::IterNextOutput,
    prelude::*,
    pyclass::IterANextOutput,
};
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

use crate::Error;
@@ -288,22 +289,20 @@ impl futures::stream::Stream for ByteStream {
    type Item = Result<Bytes, aws_smithy_http::byte_stream::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut stream = self.0.lock();
        Pin::new(&mut *stream).poll_next(cx)
        let stream = self.0.lock();
        tokio::pin!(stream);
        match stream.poll(cx) {
            Poll::Ready(mut stream) => Pin::new(&mut *stream).poll_next(cx),
            Poll::Pending => Poll::Pending,
        }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.lock().size_hint()
    }
}

/// Return a new data chunk from the stream.
// TODO(https://github.com/awslabs/smithy-rs/issues/1674) lock held across an await point needs to be removed
#[allow(clippy::await_holding_lock)]
async fn yield_data_chunk(
    body: Arc<Mutex<aws_smithy_http::byte_stream::ByteStream>>,
) -> PyResult<Bytes> {
    let mut stream = body.lock();
    let mut stream = body.lock().await;
    match stream.next().await {
        Some(bytes) => bytes.map_err(|e| PyRuntimeError::new_err(e.to_string())),
        None => Err(PyStopIteration::new_err("stream exhausted")),