diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index e26bec0a15ff5fd03783982ac324e05a9e54c37b..96da64f39caadb575fcf08f57958c0db2b22736b 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -678,6 +678,8 @@ extern "C" { // These functions are defined in OpenSSL as macros, so we shim them #[link_name = "BIO_eof_shim"] pub fn BIO_eof(b: *mut BIO) -> c_int; + #[link_name = "BIO_set_nbio_shim"] + pub fn BIO_set_nbio(b: *mut BIO, enabled: c_long) -> c_long; #[link_name = "BIO_set_mem_eof_return_shim"] pub fn BIO_set_mem_eof_return(b: *mut BIO, v: c_int); pub fn SSL_CTX_set_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long; diff --git a/openssl-sys/src/openssl_shim.c b/openssl-sys/src/openssl_shim.c index 8ebe23ac2c0390db58dab5fe015ae8e0d62ea61d..84adb47beb81b95df069dd13bd8e0e55f268ab16 100644 --- a/openssl-sys/src/openssl_shim.c +++ b/openssl-sys/src/openssl_shim.c @@ -83,6 +83,10 @@ int BIO_eof_shim(BIO *b) { return BIO_eof(b); } +long BIO_set_nbio_shim(BIO *b, long enabled) { + return BIO_set_nbio(b, enabled); +} + void BIO_set_mem_eof_return_shim(BIO *b, int v) { BIO_set_mem_eof_return(b, v); } diff --git a/openssl/Cargo.toml b/openssl/Cargo.toml index 6607ef949b2d38d609a93e5e48e55e36c4d51675..cd9f97c3ce5e75a2b0f4512a44711b5751aa3d70 100644 --- a/openssl/Cargo.toml +++ b/openssl/Cargo.toml @@ -34,3 +34,4 @@ libc = "0.1" [dev-dependencies] rustc-serialize = "0.3" net2 = "0.2.13" +nix = "0.4" diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index 9ff6cae97ef8596a52f3217f851e28b8ad8b0260..0126b2774322459fa416852d1a73c4c843718448 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -17,7 +17,20 @@ pub enum SslError { /// The SSL session has been closed by the other end SslSessionClosed, /// An error in the OpenSSL library - OpenSslErrors(Vec) + OpenSslErrors(Vec), +} + +/// An error on a nonblocking stream. +#[derive(Debug)] +pub enum NonblockingSslError { + /// A standard SSL error occurred. + SslError(SslError), + /// The OpenSSL library wants data from the remote socket; + /// the caller should wait for read readiness. + WantRead, + /// The OpenSSL library wants to send data to the remote socket; + /// the caller should wait for write readiness. + WantWrite, } impl fmt::Display for SslError { @@ -59,6 +72,35 @@ impl error::Error for SslError { } } +impl fmt::Display for NonblockingSslError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str(error::Error::description(self)) + } +} + +impl error::Error for NonblockingSslError { + fn description(&self) -> &str { + match *self { + NonblockingSslError::SslError(ref e) => e.description(), + NonblockingSslError::WantRead => "The OpenSSL library wants data from the remote socket", + NonblockingSslError::WantWrite => "The OpenSSL library want to send data to the remote socket", + } + } + + fn cause(&self) -> Option<&error::Error> { + match *self { + NonblockingSslError::SslError(ref e) => e.cause(), + _ => None + } + } +} + +impl From for NonblockingSslError { + fn from(e: SslError) -> NonblockingSslError { + NonblockingSslError::SslError(e) + } +} + /// An error from the OpenSSL library #[derive(Debug, Clone, PartialEq, Eq)] pub enum OpensslError { diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index e76529a5e942ad37584d29b603046adaac9ea310..620800560c1463df81dcda00a6a3883fcc3b5915 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -22,7 +22,7 @@ use std::slice; use bio::{MemBio}; use ffi; use dh::DH; -use ssl::error::{SslError, SslSessionClosed, StreamError, OpenSslErrors}; +use ssl::error::{NonblockingSslError, SslError, SslSessionClosed, StreamError, OpenSslErrors}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -1465,3 +1465,232 @@ impl MaybeSslStream where S: Read+Write { } } } + +/// An SSL stream wrapping a nonblocking socket. +#[derive(Clone)] +pub struct NonblockingSslStream { + stream: S, + ssl: Arc, +} + +impl NonblockingSslStream { + pub fn try_clone(&self) -> io::Result> { + Ok(NonblockingSslStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + }) + } +} + +impl NonblockingSslStream { + fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result, SslError> { + unsafe { + let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); + ffi::BIO_set_nbio(bio, 1); + ffi::SSL_set_bio(ssl.ssl, bio, bio); + } + + Ok(NonblockingSslStream { + stream: stream, + ssl: Arc::new(ssl), + }) + } + + fn make_error(&self, ret: c_int) -> NonblockingSslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + let ssl_error = if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + SslError::StreamError(io::Error::last_os_error()) + } + } else { + err + }; + ssl_error.into() + }, + LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite, + LibSslError::ErrorWantRead => NonblockingSslError::WantRead, + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } + + /// Returns a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Returns a mutable reference to the underlying stream. + /// + /// ## Warning + /// + /// It is inadvisable to read from or write to the underlying stream as it + /// will most likely corrupt the SSL session. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + + /// Returns a reference to the Ssl. + pub fn ssl(&self) -> &Ssl { + &self.ssl + } +} + +#[cfg(unix)] +impl NonblockingSslStream { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +#[cfg(unix)] +impl ::std::os::unix::io::AsRawFd for NonblockingSslStream { + fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { + self.stream.as_raw_fd() + } +} + +#[cfg(windows)] +impl NonblockingSslStream { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_socket() as c_int; + let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + // WantRead/WantWrite is okay here; we'll finish the handshake in + // subsequent send/recv calls. + match ssl.make_error(ret) { + NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), + NonblockingSslError::SslError(other) => Err(other), + } + } + } +} + +impl NonblockingSslStream { + /// Read bytes from the SSL stream into `buf`. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `read()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your read. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// On a return value of `Ok(count)`, count is the number of decrypted + /// plaintext bytes copied into the `buf` slice. + pub fn read(&mut self, buf: &mut [u8]) -> Result { + let ret = self.ssl.read(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + + /// Write bytes from `buf` to the SSL stream. + /// + /// Given the SSL state machine, this method may return either `WantWrite` + /// or `WantRead` to indicate that your event loop should respectively wait + /// for write or read readiness on the underlying stream. Upon readiness, + /// repeat your `write()` call with the same arguments each time until you + /// receive an `Ok(count)`. + /// + /// An `SslError` return value, is terminal; do not re-attempt your write. + /// + /// As expected of a nonblocking API, this method will never block your + /// thread on I/O. + /// + /// Given a return value of `Ok(count)`, count is the number of plaintext bytes + /// from the `buf` slice that were encrypted and written onto the stream. + pub fn write(&mut self, buf: &[u8]) -> Result { + let ret = self.ssl.write(buf); + if ret > 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } +} diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs index 033a3b8648722ee33c6128384704f8408f1d3f4b..8335bc5340f62f62ffe7fdb88b166121adfac9e4 100644 --- a/openssl/src/ssl/tests.rs +++ b/openssl/src/ssl/tests.rs @@ -819,3 +819,135 @@ fn test_sslv2_connect_failure() { let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]); SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap(); } + +#[cfg(target_os = "linux")] +mod nonblocking_tests { + extern crate nix; + + use std::io::Write; + use std::net::TcpStream; + use std::os::unix::io::AsRawFd; + + use super::Server; + use self::nix::sys::epoll; + use self::nix::fcntl; + use ssl; + use ssl::error::NonblockingSslError; + use ssl::SslMethod; + use ssl::SslMethod::Sslv23; + use ssl::{SslContext, NonblockingSslStream}; + + fn wait_io(stream: &NonblockingSslStream, read: bool, timeout_ms: isize) -> bool { + let fd = stream.as_raw_fd(); + let ep = epoll::epoll_create().unwrap(); + let event = if read { + epoll::EpollEvent { + events: epoll::EPOLLIN | epoll::EPOLLERR, + data: 0, + } + } else { + epoll::EpollEvent { + events: epoll::EPOLLOUT, + data: 0, + } + }; + epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap(); + let mut events = [event]; + let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap(); + epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap(); + assert!(count <= 1); + count == 1 + } + + fn make_nonblocking(stream: &TcpStream) { + let fd = stream.as_raw_fd(); + fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap(); + } + + #[test] + fn test_write_nonblocking() { + let (_s, stream) = Server::new(); + make_nonblocking(&stream); + let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of openssl. + panic!("Too many read/write round trips in handshake!!"); + } + let result = stream.write("hello".as_bytes()); + match result { + Ok(_) => { + break; + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + } + } + + // Second write should succeed immediately--plenty of space in kernel buffer, + // and handshake just completed. + stream.write(" there".as_bytes()).unwrap(); + } + + #[test] + fn test_read_nonblocking() { + let (_s, stream) = Server::new(); + make_nonblocking(&stream); + let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of openssl. + panic!("Too many read/write round trips in handshake!!"); + } + let result = stream.write("GET /\r\n\r\n".as_bytes()); + match result { + Ok(n) => { + assert_eq!(n, 9); + break; + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + } + } + let mut input_buffer = [0u8; 1500]; + let result = stream.read(&mut input_buffer); + let bytes_read = match result { + Ok(n) => { + // This branch is unlikely, but on an overloaded VM with + // unlucky context switching, the response could actually + // be in the receive buffer before we issue the read() syscall... + n + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 3000)); + // Second read should return application data. + stream.read(&mut input_buffer).unwrap() + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + }; + assert!(bytes_read >= 5); + assert_eq!(&input_buffer[..5], b"HTTP/"); + } +}