Unverified Commit dcfe1dfa authored by Steven Fackler's avatar Steven Fackler Committed by GitHub
Browse files

Merge pull request #802 from sfackler/ssl-error

Overhaul ssl error
parents 642d5bd5 52a06adc
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -27,7 +27,5 @@ openssl-sys = { version = "0.9.23", path = "../openssl-sys" }

[dev-dependencies]
tempdir = "0.3"
winapi = "0.2"
ws2_32-sys = "0.2"
hex = "0.2"
data-encoding = "2.0"
+90 −64
Original line number Diff line number Diff line
use ffi;
use libc::c_int;
use std::error;
use std::error::Error as StdError;
use std::fmt;
@@ -7,90 +9,114 @@ use error::ErrorStack;
use ssl::MidHandshakeSslStream;
use x509::X509VerifyResult;

/// An SSL error.
// FIXME this is missing variants
#[derive(Debug)]
pub enum Error {
    /// The SSL session has been closed by the other end
    ZeroReturn,
    /// An attempt to read data from the underlying socket returned
    /// `WouldBlock`. Wait for read readiness and reattempt the operation.
    WantRead(io::Error),
    /// An attempt to write data from the underlying socket returned
    /// `WouldBlock`. Wait for write readiness and reattempt the operation.
    WantWrite(io::Error),
    /// The client certificate callback requested to be called again.
    WantX509Lookup,
    /// An error reported by the underlying stream.
    Stream(io::Error),
    /// An error in the OpenSSL library.
    Ssl(ErrorStack),
/// An error code returned from SSL functions.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ErrorCode(c_int);

impl ErrorCode {
    pub fn from_raw(raw: c_int) -> ErrorCode {
        ErrorCode(raw)
    }

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str(self.description())?;
        if let Some(err) = self.cause() {
            write!(fmt, ": {}", err)
        } else {
            Ok(())
    pub fn as_raw(&self) -> c_int {
        self.0
    }

    /// The SSL session has been closed.
    pub const ZERO_RETURN: ErrorCode = ErrorCode(ffi::SSL_ERROR_ZERO_RETURN);

    /// An attempt to read data from the underlying socket returned `WouldBlock`.
    ///
    /// Wait for read readiness and retry the operation.
    pub const WANT_READ: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_READ);

    /// An attempt to write data to the underlying socket returned `WouldBlock`.
    ///
    /// Wait for write readiness and retry the operation.
    pub const WANT_WRITE: ErrorCode = ErrorCode(ffi::SSL_ERROR_WANT_WRITE);

    /// A non-recoverable IO error occurred.
    pub const SYSCALL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SYSCALL);

    /// An error occurred in the SSL library.
    pub const SSL: ErrorCode = ErrorCode(ffi::SSL_ERROR_SSL);
}

#[derive(Debug)]
pub(crate) enum InnerError {
    Io(io::Error),
    Ssl(ErrorStack),
}

impl error::Error for Error {
    fn description(&self) -> &str {
        match *self {
            Error::ZeroReturn => "The SSL session was closed by the other end",
            Error::WantRead(_) => "A read attempt returned a `WouldBlock` error",
            Error::WantWrite(_) => "A write attempt returned a `WouldBlock` error",
            Error::WantX509Lookup => "The client certificate callback requested to be called again",
            Error::Stream(_) => "The underlying stream reported an error",
            Error::Ssl(_) => "The OpenSSL library reported an error",
/// An SSL error.
#[derive(Debug)]
pub struct Error {
    pub(crate) code: ErrorCode,
    pub(crate) cause: Option<InnerError>,
}

impl Error {
    pub fn code(&self) -> ErrorCode {
        self.code
    }

    fn cause(&self) -> Option<&error::Error> {
        match *self {
            Error::WantRead(ref err) => Some(err),
            Error::WantWrite(ref err) => Some(err),
            Error::Stream(ref err) => Some(err),
            Error::Ssl(ref err) => Some(err),
    pub fn io_error(&self) -> Option<&io::Error> {
        match self.cause {
            Some(InnerError::Io(ref e)) => Some(e),
            _ => None,
        }
    }
}

impl From<ErrorStack> for Error {
    fn from(e: ErrorStack) -> Error {
        Error::Ssl(e)
    pub fn into_io_error(self) -> Result<io::Error, Error> {
        match self.cause {
            Some(InnerError::Io(e)) => Ok(e),
            _ => Err(self),
        }
    }

/// An error indicating that the operation can be immediately retried.
///
/// OpenSSL's [`SSL_read`] and [`SSL_write`] functions can return `SSL_ERROR_WANT_READ` even when
/// the underlying socket is performing blocking IO in certain cases. When this happens, the
/// the operation can be immediately retried.
///
/// To signal this event, the `io::Error` inside of [`Error::WantRead`] will be constructed around
/// a `RetryError`.
///
/// [`SSL_read`]: https://www.openssl.org/docs/manmaster/man3/SSL_read.html
/// [`SSL_write`]: https://www.openssl.org/docs/manmaster/man3/SSL_write.html
/// [`Error::WantRead`]: enum.Error.html#variant.WantRead
#[derive(Debug)]
pub struct RetryError;
    pub fn ssl_error(&self) -> Option<&ErrorStack> {
        match self.cause {
            Some(InnerError::Ssl(ref e)) => Some(e),
            _ => None,
        }
    }
}

impl fmt::Display for RetryError {
impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str(error::Error::description(self))
        match self.code {
            ErrorCode::ZERO_RETURN => fmt.write_str("the SSL session has been shut down"),
            ErrorCode::WANT_READ => match self.io_error() {
                Some(_) => fmt.write_str("a nonblocking read call would have blocked"),
                None => fmt.write_str("the operation should be retried"),
            },
            ErrorCode::SYSCALL => match self.io_error() {
                Some(err) => write!(fmt, "the inner stream returned an error: {}", err),
                None => fmt.write_str("unexpected EOF"),
            },
            ErrorCode::SSL => {
                fmt.write_str("OpenSSL error")?;
                if let Some(ref err) = self.ssl_error() {
                    write!(fmt, ": {}", err)?
                }
                Ok(())
            }
            ErrorCode(code) => write!(fmt, "unknown error code {}", code),
        }
    }
}

impl error::Error for RetryError {
impl error::Error for Error {
    fn description(&self) -> &str {
        "operation must be retried"
        "an OpenSSL error"
    }

    fn cause(&self) -> Option<&error::Error> {
        match self.cause {
            Some(InnerError::Io(ref e)) => Some(e),
            Some(InnerError::Ssl(ref e)) => Some(e),
            None => None,
        }
    }
}

+52 −81
Original line number Diff line number Diff line
@@ -97,18 +97,19 @@ use error::ErrorStack;
use ex_data::Index;
use stack::{Stack, StackRef};
use ssl::bio::BioMethod;
use ssl::error::InnerError;
use ssl::callbacks::*;

pub use ssl::connector::{ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector,
                         SslConnectorBuilder};
pub use ssl::error::{Error, HandshakeError, RetryError};
pub use ssl::error::{Error, ErrorCode, HandshakeError};

mod error;
mod callbacks;
mod connector;
mod bio;
#[cfg(test)]
mod tests;
mod test;

bitflags! {
    /// Options controlling the behavior of an `SslContext`.
@@ -1463,8 +1464,8 @@ impl SslRef {
        unsafe { ffi::SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) }
    }

    fn get_error(&self, ret: c_int) -> c_int {
        unsafe { ffi::SSL_get_error(self.as_ptr(), ret) }
    fn get_error(&self, ret: c_int) -> ErrorCode {
        unsafe { ErrorCode::from_raw(ffi::SSL_get_error(self.as_ptr(), ret)) }
    }

    /// Like [`SslContextBuilder::set_verify`].
@@ -2053,16 +2054,14 @@ impl Ssl {
        if ret > 0 {
            Ok(stream)
        } else {
            match stream.make_error(ret) {
                e @ Error::WantWrite(_) | e @ Error::WantRead(_) => {
                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
                        stream: stream,
                        error: e,
                    }))
                }
                err => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream: stream,
                    error: err,
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
@@ -2087,16 +2086,14 @@ impl Ssl {
        if ret > 0 {
            Ok(stream)
        } else {
            match stream.make_error(ret) {
                e @ Error::WantWrite(_) | e @ Error::WantRead(_) => {
                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
                        stream: stream,
                        error: e,
                    }))
                }
                err => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream: stream,
                    error: err,
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
@@ -2146,15 +2143,12 @@ impl<S> MidHandshakeSslStream<S> {
        if ret > 0 {
            Ok(self.stream)
        } else {
            match self.stream.make_error(ret) {
                e @ Error::WantWrite(_) | e @ Error::WantRead(_) => {
                    self.error = e;
            self.error = self.stream.make_error(ret);
            match self.error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
                    Err(HandshakeError::WouldBlock(self))
                }
                err => {
                    self.error = err;
                    Err(HandshakeError::Failure(self))
                }
                _ => Err(HandshakeError::Failure(self)),
            }
        }
    }
@@ -2225,12 +2219,7 @@ impl<S: Read + Write> SslStream<S> {
        if ret > 0 {
            Ok(ret as usize)
        } else {
            match self.make_error(ret) {
                // FIXME only do this in read
                // Don't treat unexpected EOFs as errors when reading
                Error::Stream(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => Ok(0),
                e => Err(e),
            }
            Err(self.make_error(ret))
        }
    }

@@ -2280,45 +2269,26 @@ impl<S> SslStream<S> {
    fn make_error(&mut self, ret: c_int) -> Error {
        self.check_panic();

        match self.ssl.get_error(ret) {
            ffi::SSL_ERROR_SSL => Error::Ssl(ErrorStack::get()),
            ffi::SSL_ERROR_SYSCALL => {
        let code = self.ssl.get_error(ret);

        let cause = match code {
            ErrorCode::SSL => Some(InnerError::Ssl(ErrorStack::get())),
            ErrorCode::SYSCALL => {
                let errs = ErrorStack::get();
                if errs.errors().is_empty() {
                    match self.get_bio_error() {
                        Some(err) => Error::Stream(err),
                        None => Error::Stream(io::Error::new(
                            io::ErrorKind::ConnectionAborted,
                            "unexpected EOF observed",
                        )),
                    }
                    self.get_bio_error().map(InnerError::Io)
                } else {
                    Error::Ssl(errs)
                    Some(InnerError::Ssl(errs))
                }
            }
            ffi::SSL_ERROR_ZERO_RETURN => Error::ZeroReturn,
            ffi::SSL_ERROR_WANT_WRITE => {
                let err = match self.get_bio_error() {
                    Some(err) => err,
                    None => io::Error::new(
                        io::ErrorKind::Other,
                        "BUG: got an SSL_ERROR_WANT_WRITE with no error in the BIO",
                    ),
                };
                Error::WantWrite(err)
            ErrorCode::ZERO_RETURN => None,
            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
                self.get_bio_error().map(InnerError::Io)
            }
            ffi::SSL_ERROR_WANT_READ => {
                let err = match self.get_bio_error() {
                    Some(err) => err,
                    None => io::Error::new(io::ErrorKind::Other, RetryError),
            _ => None,
        };
                Error::WantRead(err)
            }
            err => Error::Stream(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unexpected error {}", err),
            )),
        }

        Error { code, cause }
    }

    fn check_panic(&mut self) {
@@ -2363,13 +2333,15 @@ impl<S: Read + Write> Read for SslStream<S> {
        loop {
            match self.ssl_read(buf) {
                Ok(n) => return Ok(n),
                Err(Error::ZeroReturn) => return Ok(0),
                Err(Error::WantRead(ref e))
                    if e.get_ref().map_or(false, |e| e.is::<RetryError>()) => {}
                Err(Error::Stream(e)) | Err(Error::WantRead(e)) | Err(Error::WantWrite(e)) => {
                    return Err(e);
                Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0),
                Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {
                    return Ok(0)
                }
                Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {}
                Err(e) => {
                    return Err(e.into_io_error()
                        .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))
                }
                Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
            }
        }
    }
@@ -2380,12 +2352,11 @@ impl<S: Read + Write> Write for SslStream<S> {
        loop {
            match self.ssl_write(buf) {
                Ok(n) => return Ok(n),
                Err(Error::WantRead(ref e))
                    if e.get_ref().map_or(false, |e| e.is::<RetryError>()) => {}
                Err(Error::Stream(e)) | Err(Error::WantRead(e)) | Err(Error::WantWrite(e)) => {
                    return Err(e);
                Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {}
                Err(e) => {
                    return Err(e.into_io_error()
                        .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))
                }
                Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
            }
        }
    }
+8 −132
Original line number Diff line number Diff line
@@ -27,11 +27,9 @@ use pkey::PKey;

use std::net::UdpSocket;

mod select;

static ROOT_CERT: &'static [u8] = include_bytes!("../../../test/root-ca.pem");
static CERT: &'static [u8] = include_bytes!("../../../test/cert.pem");
static KEY: &'static [u8] = include_bytes!("../../../test/key.pem");
static ROOT_CERT: &'static [u8] = include_bytes!("../../test/root-ca.pem");
static CERT: &'static [u8] = include_bytes!("../../test/cert.pem");
static KEY: &'static [u8] = include_bytes!("../../test/key.pem");

fn next_addr() -> SocketAddr {
    use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
@@ -364,9 +362,9 @@ fn test_write_hits_stream() {

#[test]
fn test_set_certificate_and_private_key() {
    let key = include_bytes!("../../../test/key.pem");
    let key = include_bytes!("../../test/key.pem");
    let key = PKey::private_key_from_pem(key).unwrap();
    let cert = include_bytes!("../../../test/cert.pem");
    let cert = include_bytes!("../../test/cert.pem");
    let cert = X509::from_pem(cert).unwrap();

    let mut ctx = SslContext::builder(SslMethod::tls()).unwrap();
@@ -731,128 +729,6 @@ fn test_alpn_server_select_none() {
    assert_eq!(None, stream.ssl().selected_alpn_protocol());
}

fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool {
    unsafe {
        let mut set: select::fd_set = mem::zeroed();
        select::fd_set(&mut set, stream);

        let write = if read {
            0 as *mut _
        } else {
            &mut set as *mut _
        };
        let read = if !read {
            0 as *mut _
        } else {
            &mut set as *mut _
        };
        select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap()
    }
}

fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>) -> SslStream<TcpStream> {
    match res {
        Ok(s) => s,
        Err(HandshakeError::WouldBlock(s)) => {
            wait_io(s.get_ref(), true, 1_000);
            handshake(s.handshake())
        }
        Err(err) => panic!("error on handshake {:?}", err),
    }
}

#[test]
fn test_write_nonblocking() {
    let (_s, stream) = Server::new();
    stream.set_nonblocking(true).unwrap();
    let cx = SslContext::builder(SslMethod::tls()).unwrap().build();
    let mut stream = handshake(Ssl::new(&cx).unwrap().connect(stream));

    let mut iterations = 0;
    loop {
        iterations += 1;
        if iterations > 7 {
            // Probably a safe assumption for the foreseeable future of
            // openssl.
            panic!("Too many read/write round trips in handshake!!");
        }
        let result = stream.ssl_write(b"hello");
        match result {
            Ok(_) => {
                break;
            }
            Err(Error::WantRead(_)) => {
                assert!(wait_io(stream.get_ref(), true, 1000));
            }
            Err(Error::WantWrite(_)) => {
                assert!(wait_io(stream.get_ref(), false, 1000));
            }
            Err(other) => {
                panic!("Unexpected SSL Error: {:?}", other);
            }
        }
    }

    // Second write should succeed immediately--plenty of space in kernel
    // buffer, and handshake just completed.
    stream.write(" there".as_bytes()).unwrap();
}

#[test]
#[cfg_attr(any(libressl, windows, target_arch = "arm"), ignore)] // FIXME(#467)
fn test_read_nonblocking() {
    let (_s, stream) = Server::new();
    stream.set_nonblocking(true).unwrap();
    let cx = SslContext::builder(SslMethod::tls()).unwrap().build();
    let mut stream = handshake(Ssl::new(&cx).unwrap().connect(stream));

    let mut iterations = 0;
    loop {
        iterations += 1;
        if iterations > 7 {
            // Probably a safe assumption for the foreseeable future of
            // openssl.
            panic!("Too many read/write round trips in handshake!!");
        }
        let result = stream.ssl_write(b"GET /\r\n\r\n");
        match result {
            Ok(n) => {
                assert_eq!(n, 9);
                break;
            }
            Err(Error::WantRead(..)) => {
                assert!(wait_io(stream.get_ref(), true, 1000));
            }
            Err(Error::WantWrite(..)) => {
                assert!(wait_io(stream.get_ref(), false, 1000));
            }
            Err(other) => {
                panic!("Unexpected SSL Error: {:?}", other);
            }
        }
    }
    let mut input_buffer = [0u8; 1500];
    let result = stream.ssl_read(&mut input_buffer);
    let bytes_read = match result {
        Ok(n) => {
            // This branch is unlikely, but on an overloaded VM with
            // unlucky context switching, the response could actually
            // be in the receive buffer before we issue the read() syscall...
            n
        }
        Err(Error::WantRead(..)) => {
            assert!(wait_io(stream.get_ref(), true, 3000));
            // Second read should return application data.
            stream.read(&mut input_buffer).unwrap()
        }
        Err(other) => {
            panic!("Unexpected SSL Error: {:?}", other);
        }
    };
    assert!(bytes_read >= 5);
    assert_eq!(&input_buffer[..5], b"HTTP/");
}

#[test]
#[should_panic(expected = "blammo")]
fn write_panic() {
@@ -974,7 +850,7 @@ fn default_verify_paths() {

#[test]
fn add_extra_chain_cert() {
    let cert = include_bytes!("../../../test/cert.pem");
    let cert = include_bytes!("../../test/cert.pem");
    let cert = X509::from_pem(cert).unwrap();
    let mut ctx = SslContext::builder(SslMethod::tls()).unwrap();
    ctx.add_extra_chain_cert(cert).unwrap();
@@ -1226,7 +1102,7 @@ fn tmp_dh_callback() {
            .unwrap();
        ctx.set_tmp_dh_callback(|_, _, _| {
            CALLED_BACK.store(true, Ordering::SeqCst);
            let dh = include_bytes!("../../../test/dhparams.pem");
            let dh = include_bytes!("../../test/dhparams.pem");
            Dh::from_pem(dh)
        });
        let ssl = Ssl::new(&ctx.build()).unwrap();
@@ -1295,7 +1171,7 @@ fn tmp_dh_callback_ssl() {
        let mut ssl = Ssl::new(&ctx.build()).unwrap();
        ssl.set_tmp_dh_callback(|_, _, _| {
            CALLED_BACK.store(true, Ordering::SeqCst);
            let dh = include_bytes!("../../../test/dhparams.pem");
            let dh = include_bytes!("../../test/dhparams.pem");
            Dh::from_pem(dh)
        });
        ssl.accept(stream).unwrap();

openssl/src/ssl/tests/select.rs

deleted100644 → 0
+0 −74
Original line number Diff line number Diff line
use libc;
pub use self::imp::*;

#[cfg(unix)]
mod imp {
    use std::os::unix::prelude::*;
    use std::io;
    use libc;

    pub use libc::fd_set;

    pub fn fd_set<F: AsRawFd>(set: &mut fd_set, f: &F) {
        unsafe {
            libc::FD_SET(f.as_raw_fd(), set);
        }
    }

    pub unsafe fn select<F: AsRawFd>(
        max: &F,
        read: *mut fd_set,
        write: *mut fd_set,
        error: *mut fd_set,
        timeout_ms: u32,
    ) -> io::Result<bool> {
        let mut timeout = libc::timeval {
            tv_sec: (timeout_ms / 1000) as libc::time_t,
            tv_usec: (timeout_ms % 1000 * 1000) as libc::suseconds_t,
        };
        let rc = libc::select(max.as_raw_fd() + 1, read, write, error, &mut timeout);
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(rc != 0)
        }
    }
}

#[cfg(windows)]
mod imp {
    extern crate winapi;
    extern crate ws2_32;

    use std::os::windows::prelude::*;
    use std::io;
    use libc::{c_uint, c_long};
    use self::winapi::SOCKET;
    use self::winapi::winsock2;

    pub use self::winapi::winsock2::fd_set;

    pub fn fd_set<F: AsRawSocket>(set: &mut fd_set, f: &F) {
        set.fd_array[set.fd_count as usize] = f.as_raw_socket();
        set.fd_count += 1;
    }

    pub unsafe fn select<F: AsRawSocket>(
        _max: &F,
        read: *mut fd_set,
        write: *mut fd_set,
        error: *mut fd_set,
        timeout_ms: u32,
    ) -> io::Result<bool> {
        let mut timeout = winsock2::timeval {
            tv_sec: (timeout_ms / 1000) as c_long,
            tv_usec: (timeout_ms % 1000 * 1000) as c_long,
        };
        let rc = ws2_32::select(1, read, write, error, &mut timeout);
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(rc != 0)
        }
    }
}