Commit c895b9f0 authored by Alex Crichton's avatar Alex Crichton
Browse files

Get nonblocking tests working on OSX/Windows

parent c897837e
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -33,5 +33,4 @@ libc = "0.1"

[dev-dependencies]
rustc-serialize = "0.3"
net2 = "0.2.13"
nix = "0.4"
net2 = "0.2.16"
+102 −120
Original line number Diff line number Diff line
@@ -4,17 +4,21 @@ use std::fs::File;
use std::io::prelude::*;
use std::io::{self, BufReader};
use std::iter;
use std::mem;
use std::net::{TcpStream, TcpListener, SocketAddr};
use std::path::Path;
use std::process::{Command, Child, Stdio, ChildStdin};
use std::thread;

use net2::TcpStreamExt;

use crypto::hash::Type::{SHA256};
use ssl;
use ssl::SslMethod;
use ssl::SslMethod::Sslv23;
use ssl::{SslContext, SslStream, VerifyCallback};
use ssl::SSL_VERIFY_PEER;
use ssl::SslMethod::Sslv23;
use ssl::SslMethod;
use ssl::error::NonblockingSslError;
use ssl::{SslContext, SslStream, VerifyCallback, NonblockingSslStream};
use x509::X509StoreContext;
use x509::X509FileType;
use x509::X509;
@@ -29,6 +33,8 @@ use ssl::SslMethod::Sslv2;
#[cfg(feature="dtlsv1")]
use net2::UdpSocketExt;

mod select;

fn next_addr() -> SocketAddr {
    use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
    static PORT: AtomicUsize = ATOMIC_USIZE_INIT;
@@ -331,7 +337,8 @@ run_test!(verify_trusted_get_error_err, |method, stream| {
});

run_test!(verify_callback_data, |method, stream| {
    fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext, node_id: &Vec<u8>) -> bool {
    fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext,
                node_id: &Vec<u8>) -> bool {
        let cert = x509_ctx.get_current_cert();
        match cert {
            None => false,
@@ -808,7 +815,8 @@ mod dtlsv1 {
fn test_read_dtlsv1() {
    let (_s, stream) = Server::new_dtlsv1(Some("hello"));

    let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(), stream).unwrap();
    let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(),
                                                stream).unwrap();
    let mut buf = [0u8;100];
    assert!(stream.read(&mut buf).is_ok());
}
@@ -817,67 +825,40 @@ fn test_read_dtlsv1() {
#[cfg(feature = "sslv2")]
fn test_sslv2_connect_failure() {
    let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]);
    SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap();
    SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(),
                               tcp).err().unwrap();
}

#[cfg(target_os = "linux")]
mod nonblocking_tests {
    extern crate nix;

    use std::io::Write;
    use std::net::TcpStream;
    use std::os::unix::io::AsRawFd;

    use super::Server;
    use self::nix::sys::epoll;
    use self::nix::fcntl;
    use ssl;
    use ssl::error::NonblockingSslError;
    use ssl::SslMethod;
    use ssl::SslMethod::Sslv23;
    use ssl::{SslContext, NonblockingSslStream};
fn wait_io(stream: &NonblockingSslStream<TcpStream>,
           read: bool,
           timeout_ms: u32) -> bool {
    unsafe {
        let mut set: select::fd_set = mem::zeroed();
        select::fd_set(&mut set, stream.get_ref());

    fn wait_io(stream: &NonblockingSslStream<TcpStream>, read: bool, timeout_ms: isize) -> bool {
        let fd = stream.as_raw_fd();
        let ep = epoll::epoll_create().unwrap();
        let event = if read {
            epoll::EpollEvent {
                events: epoll::EPOLLIN | epoll::EPOLLERR,
                data: 0,
        let write = if read {0 as *mut _} else {&mut set as *mut _};
        let read = if !read {0 as *mut _} else {&mut set as *mut _};
        select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms)
               .unwrap()
    }
        } else {
            epoll::EpollEvent {
                events: epoll::EPOLLOUT,
                data: 0,
            }
        };
        epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap();
        let mut events = [event];
        let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap();
        epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap();
        assert!(count <= 1);
        count == 1
    }

    fn make_nonblocking(stream: &TcpStream) {
        let fd = stream.as_raw_fd();
        fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap();
}

#[test]
fn test_write_nonblocking() {
    let (_s, stream) = Server::new();
        make_nonblocking(&stream);
        let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
    stream.set_nonblocking(true).unwrap();
    let cx = SslContext::new(Sslv23).unwrap();
    let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap();

    let mut iterations = 0;
    loop {
        iterations += 1;
        if iterations > 7 {
                // Probably a safe assumption for the foreseeable future of openssl.
            // Probably a safe assumption for the foreseeable future of
            // openssl.
            panic!("Too many read/write round trips in handshake!!");
        }
            let result = stream.write("hello".as_bytes());
        let result = stream.write(b"hello");
        match result {
            Ok(_) => {
                break;
@@ -894,25 +875,27 @@ mod nonblocking_tests {
        }
    }

        // Second write should succeed immediately--plenty of space in kernel buffer,
        // and handshake just completed.
    // Second write should succeed immediately--plenty of space in kernel
    // buffer, and handshake just completed.
    stream.write(" there".as_bytes()).unwrap();
}

#[test]
fn test_read_nonblocking() {
    let (_s, stream) = Server::new();
        make_nonblocking(&stream);
        let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
    stream.set_nonblocking(true).unwrap();
    let cx = SslContext::new(Sslv23).unwrap();
    let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap();

    let mut iterations = 0;
    loop {
        iterations += 1;
        if iterations > 7 {
                // Probably a safe assumption for the foreseeable future of openssl.
            // Probably a safe assumption for the foreseeable future of
            // openssl.
            panic!("Too many read/write round trips in handshake!!");
        }
            let result = stream.write("GET /\r\n\r\n".as_bytes());
        let result = stream.write(b"GET /\r\n\r\n");
        match result {
            Ok(n) => {
                assert_eq!(n, 9);
@@ -950,4 +933,3 @@ mod nonblocking_tests {
    assert!(bytes_read >= 5);
    assert_eq!(&input_buffer[..5], b"HTTP/");
}
}
+87 −0
Original line number Diff line number Diff line
use libc;
pub use self::imp::*;

extern "system" {
    #[link_name = "select"]
    fn raw_select(nfds: libc::c_int,
                  readfds: *mut fd_set,
                  writefds: *mut fd_set,
                  errorfds: *mut fd_set,
                  timeout: *mut libc::timeval) -> libc::c_int;
}

#[cfg(unix)]
mod imp {
    use std::os::unix::prelude::*;
    use std::io;
    use libc;

    const FD_SETSIZE: usize = 1024;

    #[repr(C)]
    pub struct fd_set {
        fds_bits: [u64; FD_SETSIZE / 64]
    }

    pub fn fd_set<F: AsRawFd>(set: &mut fd_set, f: &F) {
        let fd = f.as_raw_fd() as usize;
        set.fds_bits[fd / 64] |= 1 << (fd % 64);
    }

    pub unsafe fn select<F: AsRawFd>(max: &F,
                                     read: *mut fd_set,
                                     write: *mut fd_set,
                                     error: *mut fd_set,
                                     timeout_ms: u32)
                                     -> io::Result<bool> {
        let mut timeout = libc::timeval {
            tv_sec: (timeout_ms / 1000) as libc::time_t,
            tv_usec: (timeout_ms % 1000 * 1000) as libc::suseconds_t,
        };
        let rc = super::raw_select(max.as_raw_fd() + 1, read, write, error,
                                   &mut timeout);
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(rc != 0)
        }
    }
}

#[cfg(windows)]
mod imp {
    use std::os::windows::prelude::*;
    use std::io;
    use libc::{SOCKET, c_uint, c_long, timeval};

    const FD_SETSIZE: usize = 64;

    #[repr(C)]
    pub struct fd_set {
        fd_count: c_uint,
        fd_array: [SOCKET; FD_SETSIZE],
    }

    pub fn fd_set<F: AsRawSocket>(set: &mut fd_set, f: &F) {
        set.fd_array[set.fd_count as usize] = f.as_raw_socket();
        set.fd_count += 1;
    }

    pub unsafe fn select<F: AsRawSocket>(_max: &F,
                                         read: *mut fd_set,
                                         write: *mut fd_set,
                                         error: *mut fd_set,
                                         timeout_ms: u32)
                                         -> io::Result<bool> {
        let mut timeout = timeval {
            tv_sec: (timeout_ms / 1000) as c_long,
            tv_usec: (timeout_ms % 1000 * 1000) as c_long,
        };
        let rc = super::raw_select(1, read, write, error, &mut timeout);
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(rc != 0)
        }
    }
}