Commit d6ce9afd authored by Steven Fackler's avatar Steven Fackler
Browse files

Have NonblockingSslStream delegate to SslStream

parent 1df131ff
Loading
Loading
Loading
Loading
+18 −11
Original line number Diff line number Diff line
@@ -95,6 +95,11 @@ impl OpenSslError {
        errs
    }

    /// Returns the raw OpenSSL error code for this error.
    pub fn error_code(&self) -> c_ulong {
        self.0
    }

    /// Returns the name of the library reporting the error.
    pub fn library(&self) -> &'static str {
        get_lib(self.0)
@@ -239,6 +244,17 @@ pub enum OpensslError {
    }
}

impl OpensslError {
    pub fn from_error_code(err: c_ulong) -> OpensslError {
        ffi::init();
        UnknownError {
            library: get_lib(err).to_owned(),
            function: get_func(err).to_owned(),
            reason: get_reason(err).to_owned()
        }
    }
}

fn get_lib(err: c_ulong) -> &'static str {
    unsafe {
        let cstr = ffi::ERR_lib_error_string(err);
@@ -271,7 +287,7 @@ impl SslError {
        loop {
            match unsafe { ffi::ERR_get_error() } {
                0 => break,
                err => errs.push(SslError::from_error_code(err))
                err => errs.push(OpensslError::from_error_code(err))
            }
        }
        OpenSslErrors(errs)
@@ -279,16 +295,7 @@ impl SslError {

    /// Creates an `SslError` from the raw numeric error code.
    pub fn from_error(err: c_ulong) -> SslError {
        OpenSslErrors(vec![SslError::from_error_code(err)])
    }

    fn from_error_code(err: c_ulong) -> OpensslError {
        ffi::init();
        UnknownError {
            library: get_lib(err).to_owned(),
            function: get_func(err).to_owned(),
            reason: get_reason(err).to_owned()
        }
        OpenSslErrors(vec![OpensslError::from_error_code(err)])
    }
}

+66 −147
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ use std::str;
use std::net;
use std::path::Path;
use std::ptr;
use std::sync::{Once, ONCE_INIT, Arc, Mutex};
use std::sync::{Once, ONCE_INIT, Mutex};
use std::cmp;
use std::any::Any;
#[cfg(any(feature = "npn", feature = "alpn"))]
@@ -18,11 +18,16 @@ use libc::{c_uchar, c_uint};
#[cfg(any(feature = "npn", feature = "alpn"))]
use std::slice;
use std::marker::PhantomData;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};

use ffi;
use ffi_extras;
use dh::DH;
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError};
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError,
                 OpensslError};
use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;

@@ -935,6 +940,20 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
    }
}

#[cfg(unix)]
impl<S: AsRawFd> AsRawFd for SslStream<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.get_ref().as_raw_fd()
    }
}

#[cfg(windows)]
impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
    fn as_raw_fd(&self) -> RawSocket {
        self.0.as_raw_socket()
    }
}

impl<S: Read+Write> SslStream<S> {
    fn new_base(ssl: Ssl, stream: S) -> Self {
        unsafe {
@@ -1247,65 +1266,38 @@ impl MaybeSslStream<net::TcpStream> {
/// # Deprecated
///
/// Use `SslStream` with `ssl_read` and `ssl_write`.
#[derive(Clone)]
pub struct NonblockingSslStream<S> {
    stream: S,
    ssl: Arc<Ssl>,
}
pub struct NonblockingSslStream<S>(SslStream<S>);

impl NonblockingSslStream<net::TcpStream> {
    pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
        Ok(NonblockingSslStream {
            stream: try!(self.stream.try_clone()),
            ssl: self.ssl.clone(),
        })
impl<S: Clone + Read + Write> Clone for NonblockingSslStream<S> {
    fn clone(&self) -> Self {
        NonblockingSslStream(self.0.clone())
    }
}

impl<S> NonblockingSslStream<S> {
    fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> {
        unsafe {
            let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0));
            ffi_extras::BIO_set_nbio(bio, 1);
            ffi::SSL_set_bio(ssl.ssl, bio, bio);
#[cfg(unix)]
impl<S: AsRawFd> AsRawFd for NonblockingSslStream<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.0.as_raw_fd()
    }

        Ok(NonblockingSslStream {
            stream: stream,
            ssl: Arc::new(ssl),
        })
}

    fn make_error(&self, ret: c_int) -> NonblockingSslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()),
            LibSslError::ErrorSyscall => {
                let err = SslError::get();
                let count = match err {
                    SslError::OpenSslErrors(ref v) => v.len(),
                    _ => unreachable!(),
                };
                let ssl_error = if count == 0 {
                    if ret == 0 {
                        SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                             "unexpected EOF observed"))
                    } else {
                        SslError::StreamError(io::Error::last_os_error())
#[cfg(windows)]
impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
    fn as_raw_fd(&self) -> RawSocket {
        self.0.as_raw_socket()
    }
                } else {
                    err
                };
                ssl_error.into()
            },
            LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite,
            LibSslError::ErrorWantRead => NonblockingSslError::WantRead,
            err => panic!("unexpected error {:?} with ret {}", err, ret),
}

impl NonblockingSslStream<net::TcpStream> {
    pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
        self.0.try_clone().map(NonblockingSslStream)
    }
}

impl<S> NonblockingSslStream<S> {
    /// Returns a reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.stream
        self.0.get_ref()
    }

    /// Returns a mutable reference to the underlying stream.
@@ -1315,37 +1307,23 @@ impl<S> NonblockingSslStream<S> {
    /// It is inadvisable to read from or write to the underlying stream as it
    /// will most likely corrupt the SSL session.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.stream
        self.0.get_mut()
    }

    /// Returns a reference to the Ssl.
    pub fn ssl(&self) -> &Ssl {
        &self.ssl
        self.0.ssl()
    }
}

#[cfg(unix)]
impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
impl<S: Read+Write> NonblockingSslStream<S> {
    /// Create a new nonblocking client ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.connect();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
        SslStream::connect(ssl, stream).map(NonblockingSslStream)
    }

    /// Create a new nonblocking server ssl connection on wrapped `stream`.
@@ -1354,78 +1332,25 @@ impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.accept();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
    }
}

#[cfg(unix)]
impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> {
    fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd {
        self.stream.as_raw_fd()
    }
}

#[cfg(windows)]
impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> {
    /// Create a new nonblocking client ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_socket() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.connect();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
        }
        SslStream::accept(ssl, stream).map(NonblockingSslStream)
    }

    /// Create a new nonblocking server ssl connection on wrapped `stream`.
    ///
    /// Note that this method will most likely not actually complete the SSL
    /// handshake because doing so requires several round trips; the handshake will
    /// be completed in subsequent read/write calls managed by your event loop.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_socket() as c_int;
        let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
        let ret = ssl.ssl.accept();
        if ret > 0 {
            Ok(ssl)
        } else {
            // WantRead/WantWrite is okay here; we'll finish the handshake in
            // subsequent send/recv calls.
            match ssl.make_error(ret) {
                NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
                NonblockingSslError::SslError(other) => Err(other),
            }
    fn convert_err(&self, err: Error) -> NonblockingSslError {
        match err {
            Error::ZeroReturn => SslError::SslSessionClosed.into(),
            Error::WantRead(_) => NonblockingSslError::WantRead,
            Error::WantWrite(_) => NonblockingSslError::WantWrite,
            Error::WantX509Lookup => unreachable!(),
            Error::Stream(e) => SslError::StreamError(e).into(),
            Error::Ssl(e) => {
                SslError::OpenSslErrors(e.iter()
                                         .map(|e| OpensslError::from_error_code(e.error_code()))
                                         .collect())
                              .into()
            }
        }
    }

impl<S: Read+Write> NonblockingSslStream<S> {
    /// Read bytes from the SSL stream into `buf`.
    ///
    /// Given the SSL state machine, this method may return either `WantWrite`
@@ -1442,11 +1367,10 @@ impl<S: Read+Write> NonblockingSslStream<S> {
    /// On a return value of `Ok(count)`, count is the number of decrypted
    /// plaintext bytes copied into the `buf` slice.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        match self.0.ssl_read(buf) {
            Ok(n) => Ok(n),
            Err(Error::ZeroReturn) => Ok(0),
            Err(e) => Err(self.convert_err(e))
        }
    }

@@ -1466,11 +1390,6 @@ impl<S: Read+Write> NonblockingSslStream<S> {
    /// Given a return value of `Ok(count)`, count is the number of plaintext bytes
    /// from the `buf` slice that were encrypted and written onto the stream.
    pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
        let ret = self.ssl.write(buf);
        if ret > 0 {
            Ok(ret as usize)
        } else {
            Err(self.make_error(ret))
        }
        self.0.ssl_write(buf).map_err(|e| self.convert_err(e))
    }
}