From c895b9f09f25410b6e0046d71d42439f7a56c1be Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 22 Sep 2015 21:34:12 -0700 Subject: [PATCH] Get nonblocking tests working on OSX/Windows --- openssl/Cargo.toml | 3 +- openssl/src/ssl/{tests.rs => tests/mod.rs} | 222 ++++++++++----------- openssl/src/ssl/tests/select.rs | 87 ++++++++ 3 files changed, 190 insertions(+), 122 deletions(-) rename openssl/src/ssl/{tests.rs => tests/mod.rs} (86%) create mode 100644 openssl/src/ssl/tests/select.rs diff --git a/openssl/Cargo.toml b/openssl/Cargo.toml index cd9f97c3c..239492591 100644 --- a/openssl/Cargo.toml +++ b/openssl/Cargo.toml @@ -33,5 +33,4 @@ libc = "0.1" [dev-dependencies] rustc-serialize = "0.3" -net2 = "0.2.13" -nix = "0.4" +net2 = "0.2.16" diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests/mod.rs similarity index 86% rename from openssl/src/ssl/tests.rs rename to openssl/src/ssl/tests/mod.rs index 8335bc534..e34c633f8 100644 --- a/openssl/src/ssl/tests.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -4,17 +4,21 @@ use std::fs::File; use std::io::prelude::*; use std::io::{self, BufReader}; use std::iter; +use std::mem; use std::net::{TcpStream, TcpListener, SocketAddr}; use std::path::Path; use std::process::{Command, Child, Stdio, ChildStdin}; use std::thread; +use net2::TcpStreamExt; + use crypto::hash::Type::{SHA256}; use ssl; -use ssl::SslMethod; -use ssl::SslMethod::Sslv23; -use ssl::{SslContext, SslStream, VerifyCallback}; use ssl::SSL_VERIFY_PEER; +use ssl::SslMethod::Sslv23; +use ssl::SslMethod; +use ssl::error::NonblockingSslError; +use ssl::{SslContext, SslStream, VerifyCallback, NonblockingSslStream}; use x509::X509StoreContext; use x509::X509FileType; use x509::X509; @@ -29,6 +33,8 @@ use ssl::SslMethod::Sslv2; #[cfg(feature="dtlsv1")] use net2::UdpSocketExt; +mod select; + fn next_addr() -> SocketAddr { use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering}; static PORT: AtomicUsize = ATOMIC_USIZE_INIT; @@ -331,7 +337,8 @@ run_test!(verify_trusted_get_error_err, |method, stream| { }); run_test!(verify_callback_data, |method, stream| { - fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext, node_id: &Vec) -> bool { + fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext, + node_id: &Vec) -> bool { let cert = x509_ctx.get_current_cert(); match cert { None => false, @@ -808,7 +815,8 @@ mod dtlsv1 { fn test_read_dtlsv1() { let (_s, stream) = Server::new_dtlsv1(Some("hello")); - let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(), stream).unwrap(); + let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(), + stream).unwrap(); let mut buf = [0u8;100]; assert!(stream.read(&mut buf).is_ok()); } @@ -817,137 +825,111 @@ fn test_read_dtlsv1() { #[cfg(feature = "sslv2")] fn test_sslv2_connect_failure() { let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]); - SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap(); + SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), + tcp).err().unwrap(); } -#[cfg(target_os = "linux")] -mod nonblocking_tests { - extern crate nix; - - use std::io::Write; - use std::net::TcpStream; - use std::os::unix::io::AsRawFd; +fn wait_io(stream: &NonblockingSslStream, + read: bool, + timeout_ms: u32) -> bool { + unsafe { + let mut set: select::fd_set = mem::zeroed(); + select::fd_set(&mut set, stream.get_ref()); - use super::Server; - use self::nix::sys::epoll; - use self::nix::fcntl; - use ssl; - use ssl::error::NonblockingSslError; - use ssl::SslMethod; - use ssl::SslMethod::Sslv23; - use ssl::{SslContext, NonblockingSslStream}; - - fn wait_io(stream: &NonblockingSslStream, read: bool, timeout_ms: isize) -> bool { - let fd = stream.as_raw_fd(); - let ep = epoll::epoll_create().unwrap(); - let event = if read { - epoll::EpollEvent { - events: epoll::EPOLLIN | epoll::EPOLLERR, - data: 0, - } - } else { - epoll::EpollEvent { - events: epoll::EPOLLOUT, - data: 0, - } - }; - epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap(); - let mut events = [event]; - let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap(); - epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap(); - assert!(count <= 1); - count == 1 + let write = if read {0 as *mut _} else {&mut set as *mut _}; + let read = if !read {0 as *mut _} else {&mut set as *mut _}; + select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms) + .unwrap() } +} - fn make_nonblocking(stream: &TcpStream) { - let fd = stream.as_raw_fd(); - fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap(); - } - - #[test] - fn test_write_nonblocking() { - let (_s, stream) = Server::new(); - make_nonblocking(&stream); - let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); - - let mut iterations = 0; - loop { - iterations += 1; - if iterations > 7 { - // Probably a safe assumption for the foreseeable future of openssl. - panic!("Too many read/write round trips in handshake!!"); - } - let result = stream.write("hello".as_bytes()); - match result { - Ok(_) => { - break; - }, - Err(NonblockingSslError::WantRead) => { - assert!(wait_io(&stream, true, 1000)); - }, - Err(NonblockingSslError::WantWrite) => { - assert!(wait_io(&stream, false, 1000)); - }, - Err(other) => { - panic!("Unexpected SSL Error: {:?}", other); - }, - } +#[test] +fn test_write_nonblocking() { + let (_s, stream) = Server::new(); + stream.set_nonblocking(true).unwrap(); + let cx = SslContext::new(Sslv23).unwrap(); + let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of + // openssl. + panic!("Too many read/write round trips in handshake!!"); + } + let result = stream.write(b"hello"); + match result { + Ok(_) => { + break; + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, } - - // Second write should succeed immediately--plenty of space in kernel buffer, - // and handshake just completed. - stream.write(" there".as_bytes()).unwrap(); } - #[test] - fn test_read_nonblocking() { - let (_s, stream) = Server::new(); - make_nonblocking(&stream); - let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + // Second write should succeed immediately--plenty of space in kernel + // buffer, and handshake just completed. + stream.write(" there".as_bytes()).unwrap(); +} - let mut iterations = 0; - loop { - iterations += 1; - if iterations > 7 { - // Probably a safe assumption for the foreseeable future of openssl. - panic!("Too many read/write round trips in handshake!!"); - } - let result = stream.write("GET /\r\n\r\n".as_bytes()); - match result { - Ok(n) => { - assert_eq!(n, 9); - break; - }, - Err(NonblockingSslError::WantRead) => { - assert!(wait_io(&stream, true, 1000)); - }, - Err(NonblockingSslError::WantWrite) => { - assert!(wait_io(&stream, false, 1000)); - }, - Err(other) => { - panic!("Unexpected SSL Error: {:?}", other); - }, - } +#[test] +fn test_read_nonblocking() { + let (_s, stream) = Server::new(); + stream.set_nonblocking(true).unwrap(); + let cx = SslContext::new(Sslv23).unwrap(); + let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap(); + + let mut iterations = 0; + loop { + iterations += 1; + if iterations > 7 { + // Probably a safe assumption for the foreseeable future of + // openssl. + panic!("Too many read/write round trips in handshake!!"); } - let mut input_buffer = [0u8; 1500]; - let result = stream.read(&mut input_buffer); - let bytes_read = match result { + let result = stream.write(b"GET /\r\n\r\n"); + match result { Ok(n) => { - // This branch is unlikely, but on an overloaded VM with - // unlucky context switching, the response could actually - // be in the receive buffer before we issue the read() syscall... - n + assert_eq!(n, 9); + break; }, Err(NonblockingSslError::WantRead) => { - assert!(wait_io(&stream, true, 3000)); - // Second read should return application data. - stream.read(&mut input_buffer).unwrap() + assert!(wait_io(&stream, true, 1000)); + }, + Err(NonblockingSslError::WantWrite) => { + assert!(wait_io(&stream, false, 1000)); }, Err(other) => { panic!("Unexpected SSL Error: {:?}", other); }, - }; - assert!(bytes_read >= 5); - assert_eq!(&input_buffer[..5], b"HTTP/"); + } } + let mut input_buffer = [0u8; 1500]; + let result = stream.read(&mut input_buffer); + let bytes_read = match result { + Ok(n) => { + // This branch is unlikely, but on an overloaded VM with + // unlucky context switching, the response could actually + // be in the receive buffer before we issue the read() syscall... + n + }, + Err(NonblockingSslError::WantRead) => { + assert!(wait_io(&stream, true, 3000)); + // Second read should return application data. + stream.read(&mut input_buffer).unwrap() + }, + Err(other) => { + panic!("Unexpected SSL Error: {:?}", other); + }, + }; + assert!(bytes_read >= 5); + assert_eq!(&input_buffer[..5], b"HTTP/"); } diff --git a/openssl/src/ssl/tests/select.rs b/openssl/src/ssl/tests/select.rs new file mode 100644 index 000000000..fcdf40043 --- /dev/null +++ b/openssl/src/ssl/tests/select.rs @@ -0,0 +1,87 @@ +use libc; +pub use self::imp::*; + +extern "system" { + #[link_name = "select"] + fn raw_select(nfds: libc::c_int, + readfds: *mut fd_set, + writefds: *mut fd_set, + errorfds: *mut fd_set, + timeout: *mut libc::timeval) -> libc::c_int; +} + +#[cfg(unix)] +mod imp { + use std::os::unix::prelude::*; + use std::io; + use libc; + + const FD_SETSIZE: usize = 1024; + + #[repr(C)] + pub struct fd_set { + fds_bits: [u64; FD_SETSIZE / 64] + } + + pub fn fd_set(set: &mut fd_set, f: &F) { + let fd = f.as_raw_fd() as usize; + set.fds_bits[fd / 64] |= 1 << (fd % 64); + } + + pub unsafe fn select(max: &F, + read: *mut fd_set, + write: *mut fd_set, + error: *mut fd_set, + timeout_ms: u32) + -> io::Result { + let mut timeout = libc::timeval { + tv_sec: (timeout_ms / 1000) as libc::time_t, + tv_usec: (timeout_ms % 1000 * 1000) as libc::suseconds_t, + }; + let rc = super::raw_select(max.as_raw_fd() + 1, read, write, error, + &mut timeout); + if rc < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(rc != 0) + } + } +} + +#[cfg(windows)] +mod imp { + use std::os::windows::prelude::*; + use std::io; + use libc::{SOCKET, c_uint, c_long, timeval}; + + const FD_SETSIZE: usize = 64; + + #[repr(C)] + pub struct fd_set { + fd_count: c_uint, + fd_array: [SOCKET; FD_SETSIZE], + } + + pub fn fd_set(set: &mut fd_set, f: &F) { + set.fd_array[set.fd_count as usize] = f.as_raw_socket(); + set.fd_count += 1; + } + + pub unsafe fn select(_max: &F, + read: *mut fd_set, + write: *mut fd_set, + error: *mut fd_set, + timeout_ms: u32) + -> io::Result { + let mut timeout = timeval { + tv_sec: (timeout_ms / 1000) as c_long, + tv_usec: (timeout_ms % 1000 * 1000) as c_long, + }; + let rc = super::raw_select(1, read, write, error, &mut timeout); + if rc < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(rc != 0) + } + } +} -- GitLab