Commit c37767df authored by Jamie Turner's avatar Jamie Turner
Browse files

Nonblocking streams support.

parent 549f2573
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -678,6 +678,8 @@ extern "C" {
    // These functions are defined in OpenSSL as macros, so we shim them
    #[link_name = "BIO_eof_shim"]
    pub fn BIO_eof(b: *mut BIO) -> c_int;
    #[link_name = "BIO_set_nbio_shim"]
    pub fn BIO_set_nbio(b: *mut BIO, enabled: c_long) -> c_long;
    #[link_name = "BIO_set_mem_eof_return_shim"]
    pub fn BIO_set_mem_eof_return(b: *mut BIO, v: c_int);
    pub fn SSL_CTX_set_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long;
+4 −0
Original line number Diff line number Diff line
@@ -83,6 +83,10 @@ int BIO_eof_shim(BIO *b) {
    return BIO_eof(b);
}

long BIO_set_nbio_shim(BIO *b, long enabled) {
    return BIO_set_nbio(b, enabled);
}

void BIO_set_mem_eof_return_shim(BIO *b, int v) {
    BIO_set_mem_eof_return(b, v);
}
+1 −0
Original line number Diff line number Diff line
@@ -34,3 +34,4 @@ libc = "0.1"
[dev-dependencies]
rustc-serialize = "0.3"
net2 = "0.2.13"
nix = "0.4"
+43 −1
Original line number Diff line number Diff line
@@ -17,7 +17,20 @@ pub enum SslError {
    /// The SSL session has been closed by the other end
    SslSessionClosed,
    /// An error in the OpenSSL library
    OpenSslErrors(Vec<OpensslError>)
    OpenSslErrors(Vec<OpensslError>),
}

/// An error on a nonblocking stream.
#[derive(Debug)]
pub enum NonblockingSslError {
    /// A standard SSL error occurred.
    SslError(SslError),
    /// The OpenSSL library wants data from the remote socket;
    /// the caller should wait for read readiness.
    WantRead,
    /// The OpenSSL library wants to send data to the remote socket;
    /// the caller should wait for write readiness.
    WantWrite,
}

impl fmt::Display for SslError {
@@ -59,6 +72,35 @@ impl error::Error for SslError {
    }
}

impl fmt::Display for NonblockingSslError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str(error::Error::description(self))
    }
}

impl error::Error for NonblockingSslError {
    fn description(&self) -> &str {
        match *self {
            NonblockingSslError::SslError(ref e) => e.description(),
            NonblockingSslError::WantRead => "The OpenSSL library wants data from the remote socket",
            NonblockingSslError::WantWrite => "The OpenSSL library want to send data to the remote socket",
        }
    }

    fn cause(&self) -> Option<&error::Error> {
        match *self {
            NonblockingSslError::SslError(ref e) => e.cause(),
            _ => None
        }
    }
}

impl From<SslError> for NonblockingSslError {
    fn from(e: SslError) -> NonblockingSslError {
        NonblockingSslError::SslError(e)
    }
}

/// An error from the OpenSSL library
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpensslError {
+230 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ use std::slice;
use bio::{MemBio};
use ffi;
use dh::DH;
use ssl::error::{SslError, SslSessionClosed, StreamError, OpenSslErrors};
use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors};
use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;

@@ -1465,3 +1465,232 @@ impl<S> MaybeSslStream<S> where S: Read+Write {
        }
    }
}

/// An SSL stream wrapping a nonblocking socket.
#[derive(Clone)]
pub struct NonblockingSslStream<S> {
    stream: S,
    ssl: Arc<Ssl>,
}

impl NonblockingSslStream<net::TcpStream> {
    pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
        Ok(NonblockingSslStream {
            stream: try!(self.stream.try_clone()),
            ssl: self.ssl.clone(),
        })
    }
}

impl<S> NonblockingSslStream<S> {
    fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> {
        unsafe {
            let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0));
            ffi::BIO_set_nbio(bio, 1);
            ffi::SSL_set_bio(ssl.ssl, bio, bio);
        }

        Ok(NonblockingSslStream {
            stream: stream,
            ssl: Arc::new(ssl),
        })
    }

    fn make_error(&self, ret: c_int) -> NonblockingSslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()),
            LibSslError::ErrorSyscall => {
                let err = SslError::get();
                let count = match err {
                    SslError::OpenSslErrors(ref v) => v.len(),
                    _ => unreachable!(),
                };
                let ssl_error = if count == 0 {
                    if ret == 0 {
                        SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                             "unexpected EOF observed"))
                    } else {
                        SslError::StreamError(io::Error::last_os_error())
                    }
                } else {
                    err
                };
                ssl_error.into()
            },
            LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite,
            LibSslError::ErrorWantRead => NonblockingSslError::WantRead,
            err => panic!("unexpected error {:?} with ret {}", err, ret),
        }
    }

    /// Returns a reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.stream
    }

    /// Returns a mutable reference to the underlying stream.
    ///
    /// ## Warning
    ///
    /// It is inadvisable to read from or write to the underlying stream as it
    /// will most likely corrupt the SSL session.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Returns a reference to the Ssl.
    pub fn ssl(&self) -> &Ssl {
        &self.ssl
    }
}

#[cfg(unix)]
impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
    /// Create a new nonblocking client ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.connect();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
    }

    /// Create a new nonblocking server ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.accept();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
    }
}

#[cfg(unix)]
impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> {
    fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd {
        self.stream.as_raw_fd()
    }
}

#[cfg(windows)]
impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> {
    /// Create a new nonblocking client ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_socket() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.connect();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
    }

    /// Create a new nonblocking server ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_socket() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.accept();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
    }
}

impl<S: Read+Write> NonblockingSslStream<S> {
    /// Read bytes from the SSL stream into `buf`.
    ///
    /// Given the SSL state machine, this method may return either `WantWrite`
    /// or `WantRead` to indicate that your event loop should respectively wait
    /// for write or read readiness on the underlying stream.  Upon readiness,
    /// repeat your `read()` call with the same arguments each time until you
    /// receive an `Ok(count)`.
    ///
    /// An `SslError` return value, is terminal; do not re-attempt your read.
    ///
    /// As expected of a nonblocking API, this method will never block your
    /// thread on I/O.
    ///
    /// On a return value of `Ok(count)`, count is the number of decrypted
    /// plaintext bytes copied into the `buf` slice.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        }
    }

    /// Write bytes from `buf` to the SSL stream.
    ///
    /// Given the SSL state machine, this method may return either `WantWrite`
    /// or `WantRead` to indicate that your event loop should respectively wait
    /// for write or read readiness on the underlying stream.  Upon readiness,
    /// repeat your `write()` call with the same arguments each time until you
    /// receive an `Ok(count)`.
    ///
    /// An `SslError` return value, is terminal; do not re-attempt your write.
    ///
    /// As expected of a nonblocking API, this method will never block your
    /// thread on I/O.
    ///
    /// Given a return value of `Ok(count)`, count is the number of plaintext bytes
    /// from the `buf` slice that were encrypted and written onto the stream.
    pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
        let ret = self.ssl.write(buf);
        if ret > 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        }
    }
}
Loading