Commit 1df131ff authored by Steven Fackler's avatar Steven Fackler
Browse files

Build out a new error type

parent aa37dba0
Loading
Loading
Loading
Loading
+136 −12
Original line number Diff line number Diff line
@@ -3,12 +3,136 @@ pub use self::OpensslError::*;

use libc::c_ulong;
use std::error;
use std::error::Error as StdError;
use std::fmt;
use std::ffi::CStr;
use std::io;
use std::str;

use ffi;

/// An SSL error.
#[derive(Debug)]
pub enum Error {
    /// The SSL session has been closed by the other end
    ZeroReturn,
    /// An attempt to read data from the underlying socket returned
    /// `WouldBlock`. Wait for read readiness and reattempt the operation.
    WantRead(io::Error),
    /// An attempt to write data from the underlying socket returned
    /// `WouldBlock`. Wait for write readiness and reattempt the operation.
    WantWrite(io::Error),
    #[doc(hidden)] // unused for now
    WantX509Lookup,
    /// An error reported by the underlying stream.
    Stream(io::Error),
    /// An error in the OpenSSL library.
    Ssl(Vec<OpenSslError>),
}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        try!(fmt.write_str(self.description()));
        match *self {
            Error::Stream(ref err) => write!(fmt, ": {}", err),
            Error::WantRead(ref err) => write!(fmt, ": {}", err),
            Error::WantWrite(ref err) => write!(fmt, ": {}", err),
            Error::Ssl(ref errs) => {
                let mut first = true;
                for err in errs {
                    if first {
                        try!(fmt.write_str(": "));
                        first = false;
                    } else {
                        try!(fmt.write_str(", "));
                    }
                    try!(fmt.write_str(&err.reason()))
                }
                Ok(())
            }
            _ => Ok(())
        }
    }
}

impl error::Error for Error {
    fn description(&self) -> &str {
        match *self {
            Error::ZeroReturn => "The SSL session was closed by the other end",
            Error::WantRead(_) => "A read attempt returned a `WouldBlock` error",
            Error::WantWrite(_) => "A write attempt returned a `WouldBlock` error",
            Error::WantX509Lookup => "The client certificate callback requested to be called again",
            Error::Stream(_) => "The underlying stream reported an error",
            Error::Ssl(_) => "The OpenSSL library reported an error",
        }
    }

    fn cause(&self) -> Option<&error::Error> {
        match *self {
            Error::WantRead(ref err) => Some(err),
            Error::WantWrite(ref err) => Some(err),
            Error::Stream(ref err) => Some(err),
            _ => None
        }
    }
}

/// An error reported from OpenSSL.
pub struct OpenSslError(c_ulong);

impl OpenSslError {
    /// Returns the contents of the OpenSSL error stack.
    pub fn get_stack() -> Vec<OpenSslError> {
        ffi::init();

        let mut errs = vec!();
        loop {
            match unsafe { ffi::ERR_get_error() } {
                0 => break,
                err => errs.push(OpenSslError(err))
            }
        }
        errs
    }

    /// Returns the name of the library reporting the error.
    pub fn library(&self) -> &'static str {
        get_lib(self.0)
    }

    /// Returns the name of the function reporting the error.
    pub fn function(&self) -> &'static str {
        get_func(self.0)
    }

    /// Returns the reason for the error.
    pub fn reason(&self) -> &'static str {
        get_reason(self.0)
    }
}

impl fmt::Debug for OpenSslError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("OpenSslError")
           .field("library", &self.library())
           .field("function", &self.function())
           .field("reason", &self.reason())
           .finish()
    }
}

impl fmt::Display for OpenSslError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.write_str(&self.reason())
    }
}

impl error::Error for OpenSslError {
    fn description(&self) -> &str {
        "An OpenSSL error"
    }
}

/// An SSL error
#[derive(Debug)]
pub enum SslError {
@@ -115,27 +239,27 @@ pub enum OpensslError {
    }
}

fn get_lib(err: c_ulong) -> String {
fn get_lib(err: c_ulong) -> &'static str {
    unsafe {
        let cstr = ffi::ERR_lib_error_string(err);
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec();
        String::from_utf8(bytes).unwrap()
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes();
        str::from_utf8(bytes).unwrap()
    }
}

fn get_func(err: c_ulong) -> String {
fn get_func(err: c_ulong) -> &'static str {
    unsafe {
        let cstr = ffi::ERR_func_error_string(err);
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec();
        String::from_utf8(bytes).unwrap()
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes();
        str::from_utf8(bytes).unwrap()
    }
}

fn get_reason(err: c_ulong) -> String {
fn get_reason(err: c_ulong) -> &'static str {
    unsafe {
        let cstr = ffi::ERR_reason_error_string(err);
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec();
        String::from_utf8(bytes).unwrap()
        let bytes = CStr::from_ptr(cstr as *const _).to_bytes();
        str::from_utf8(bytes).unwrap()
    }
}

@@ -161,9 +285,9 @@ impl SslError {
    fn from_error_code(err: c_ulong) -> OpensslError {
        ffi::init();
        UnknownError {
            library: get_lib(err),
            function: get_func(err),
            reason: get_reason(err)
            library: get_lib(err).to_owned(),
            function: get_func(err).to_owned(),
            reason: get_reason(err).to_owned()
        }
    }
}
+95 −23
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ use std::marker::PhantomData;
use ffi;
use ffi_extras;
use dh::DH;
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors};
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError};
use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;

@@ -31,6 +31,9 @@ mod bio;
#[cfg(test)]
mod tests;

#[doc(inline)]
pub use ssl::error::Error;

extern "C" {
    fn rust_SSL_clone(ssl: *mut ffi::SSL);
}
@@ -954,7 +957,17 @@ impl<S: Read+Write> SslStream<S> {
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
            match stream.make_old_error(ret) {
                SslError::StreamError(e) => {
                    // This is fine - nonblocking sockets will finish the handshake in read/write
                    if e.kind() == io::ErrorKind::WouldBlock {
                        Ok(stream)
                    } else {
                        Err(SslError::StreamError(e))
                    }
                }
                e => Err(e)
            }
        }
    }

@@ -966,7 +979,17 @@ impl<S: Read+Write> SslStream<S> {
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
            match stream.make_old_error(ret) {
                SslError::StreamError(e) => {
                    // This is fine - nonblocking sockets will finish the handshake in read/write
                    if e.kind() == io::ErrorKind::WouldBlock {
                        Ok(stream)
                    } else {
                        Err(SslError::StreamError(e))
                    }
                }
                e => Err(e)
            }
        }
    }

@@ -986,7 +1009,31 @@ impl<S: Read+Write> SslStream<S> {
}

impl<S> SslStream<S> {
    fn make_error(&mut self, ret: c_int) -> SslError {
    fn make_error(&mut self, ret: c_int) -> Error {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()),
            LibSslError::ErrorSyscall => {
                let errs = OpenSslError::get_stack();
                if errs.is_empty() {
                    if ret == 0 {
                        Error::Stream(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                     "unexpected EOF observed"))
                    } else {
                        Error::Stream(self.get_bio_error())
                    }
                } else {
                    Error::Ssl(errs)
                }
            }
            LibSslError::ErrorZeroReturn => Error::ZeroReturn,
            LibSslError::ErrorWantWrite => Error::WantWrite(self.get_bio_error()),
            LibSslError::ErrorWantRead => Error::WantRead(self.get_bio_error()),
            err => Error::Stream(io::Error::new(io::ErrorKind::Other,
                                                format!("unexpected error {:?}", err))),
        }
    }

    fn make_old_error(&mut self, ret: c_int) -> SslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => SslError::get(),
            LibSslError::ErrorSyscall => {
@@ -1045,6 +1092,32 @@ impl<S> SslStream<S> {
        }
    }

    /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`.
    ///
    /// This is particularly useful with a nonblocking socket, where the error
    /// value will identify if OpenSSL is waiting on read or write readiness.
    pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        }
    }

    /// Like `write`, but returns an `ssl::Error` rather than an `io::Error`.
    ///
    /// This is particularly useful with a nonblocking socket, where the error
    /// value will identify if OpenSSL is waiting on read or write readiness.
    pub fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, Error> {
        let ret = self.ssl.write(buf);
        if ret >= 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        }
    }

    /// Returns the OpenSSL `Ssl` object associated with this stream.
    pub fn ssl(&self) -> &Ssl {
        &self.ssl
@@ -1061,30 +1134,27 @@ impl SslStream<::std::net::TcpStream> {

impl<S: Read> Read for SslStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::SslSessionClosed => Ok(0),
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        match self.ssl_read(buf) {
            Ok(n) => Ok(n),
            Err(Error::ZeroReturn) => Ok(0),
            Err(Error::Stream(e)) => Err(e),
            Err(Error::WantRead(e)) => Err(e),
            Err(Error::WantWrite(e)) => Err(e),
            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }
}

impl<S: Write> Write for SslStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let ret = self.ssl.write(buf);
        if ret > 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        self.ssl_write(buf).map_err(|e| {
            match e {
                Error::Stream(e) => e,
                Error::WantRead(e) => e,
                Error::WantWrite(e) => e,
                e => io::Error::new(io::ErrorKind::Other, e),
            }
        })
    }

    fn flush(&mut self) -> io::Result<()> {
@@ -1174,7 +1244,9 @@ impl MaybeSslStream<net::TcpStream> {
    }
}

/// An SSL stream wrapping a nonblocking socket.
/// # Deprecated
///
/// Use `SslStream` with `ssl_read` and `ssl_write`.
#[derive(Clone)]
pub struct NonblockingSslStream<S> {
    stream: S,