Commit 9b235a7b authored by Steven Fackler's avatar Steven Fackler
Browse files

Prepare for direct stream support

parent c722f889
Loading
Loading
Loading
Loading
+165 −78
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ use std::ffi::{CStr, CString};
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::iter;
use std::mem;
use std::net;
use std::path::Path;
@@ -740,52 +739,181 @@ make_LibSslError! {
    ErrorWantAccept = SSL_ERROR_WANT_ACCEPT
}

struct IndirectStream<S> {
    stream: S,
    ssl: Arc<Ssl>,
    // Max TLS record size is 16k
    buf: Box<[u8; 16 * 1024]>,
}

impl<S: Clone> Clone for IndirectStream<S> {
    fn clone(&self) -> IndirectStream<S> {
        IndirectStream {
            stream: self.stream.clone(),
            ssl: self.ssl.clone(),
            buf: Box::new(*self.buf)
        }
    }
}

impl IndirectStream<net::TcpStream> {
    fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> {
        Ok(IndirectStream {
            stream: try!(self.stream.try_clone()),
            ssl: self.ssl.clone(),
            buf: Box::new(*self.buf)
        })
    }
}

impl<S: Read+Write> IndirectStream<S> {
    fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        Ok(IndirectStream {
            stream: stream,
            ssl: Arc::new(ssl),
            buf: Box::new([0; 16 * 1024]),
        })
    }

    fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
        let mut ssl = try!(IndirectStream::new_base(ssl, stream));
        try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
        Ok(ssl)
    }

    fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
        let mut ssl = try!(IndirectStream::new_base(ssl, stream));
        try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
        Ok(ssl)
    }

    fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError>
            where F: FnMut(&Ssl) -> c_int {
        loop {
            let ret = blk(&self.ssl);
            if ret > 0 {
                return Ok(ret);
            }

            let e = self.ssl.get_error(ret);
            match e {
                LibSslError::ErrorWantRead => {
                    try_ssl_stream!(self.flush());
                    let len = try_ssl_stream!(self.stream.read(&mut self.buf[..]));
                    if len == 0 {
                        self.ssl.get_rbio().set_eof(true);
                    } else {
                        try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len]));
                    }
                }
                LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) }
                LibSslError::ErrorZeroReturn => return Err(SslSessionClosed),
                LibSslError::ErrorSsl => return Err(SslError::get()),
                LibSslError::ErrorSyscall if ret == 0 => return Ok(0),
                err => panic!("unexpected error {:?} with ret {}", err, ret),
            }
        }
    }

    fn write_through(&mut self) -> io::Result<()> {
        io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
    }
}

impl<S: Read+Write> Read for IndirectStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
            Ok(len) => Ok(len as usize),
            Err(SslSessionClosed) => Ok(0),
            Err(StreamError(e)) => Err(e),
            Err(e @ OpenSslErrors(_)) => {
                Err(io::Error::new(io::ErrorKind::Other, e))
            }
        }
    }
}

impl<S: Read+Write> Write for IndirectStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
            Ok(len) => len as usize,
            Err(SslSessionClosed) => 0,
            Err(StreamError(e)) => return Err(e),
            Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
        };
        try!(self.write_through());
        Ok(count)
    }

    fn flush(&mut self) -> io::Result<()> {
        try!(self.write_through());
        self.stream.flush()
    }
}

#[derive(Clone)]
enum StreamKind<S> {
    Indirect(IndirectStream<S>),
}

impl<S> StreamKind<S> {
    fn stream(&self) -> &S {
        match *self {
            StreamKind::Indirect(ref s) => &s.stream
        }
    }

    fn mut_stream(&mut self) -> &mut S {
        match *self {
            StreamKind::Indirect(ref mut s) => &mut s.stream
        }
    }

    fn ssl(&self) -> &Ssl {
        match *self {
            StreamKind::Indirect(ref s) => &s.ssl
        }
    }
}

/// A stream wrapper which handles SSL encryption for an underlying stream.
#[derive(Clone)]
pub struct SslStream<S> {
    stream: S,
    ssl: Arc<Ssl>,
    buf: Vec<u8>
    kind: StreamKind<S>,
}

impl SslStream<net::TcpStream> {
    /// Create a new independently owned handle to the underlying socket.
    pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
        let kind = match self.kind {
            StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
        };
        Ok(SslStream {
            stream: try!(self.stream.try_clone()),
            ssl: self.ssl.clone(),
            buf: self.buf.clone(),
            kind: kind
        })
    }
}

impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl)
        write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl())
    }
}

impl<S: Read+Write> SslStream<S> {
    fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
    pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let stream = try!(IndirectStream::new_client(ssl, stream));
        Ok(SslStream {
            stream: stream,
            ssl: Arc::new(ssl),
            // Maximum TLS record size is 16k
            buf: iter::repeat(0).take(16 * 1024).collect(),
            kind: StreamKind::Indirect(stream)
        })
    }

    pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let mut ssl = try!(SslStream::new_base(ssl, stream));
        try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
        Ok(ssl)
    }

    pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let mut ssl = try!(SslStream::new_base(ssl, stream));
        try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
        Ok(ssl)
        let stream = try!(IndirectStream::new_server(ssl, stream));
        Ok(SslStream {
            kind: StreamKind::Indirect(stream)
        })
    }

    /// # Deprecated
@@ -811,12 +939,12 @@ impl<S: Read+Write> SslStream<S> {

    /// Returns a reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.stream
        self.kind.stream()
    }

    /// Return the certificate of the peer
    pub fn get_peer_certificate(&self) -> Option<X509> {
        self.ssl.get_peer_certificate()
        self.kind.ssl().get_peer_certificate()
    }

    /// Returns a mutable reference to the underlying stream.
@@ -826,46 +954,14 @@ impl<S: Read+Write> SslStream<S> {
    /// It is inadvisable to read from or write to the underlying stream as it
    /// will most likely corrupt the SSL session.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    fn in_retry_wrapper<F>(&mut self, mut blk: F)
            -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int {
        loop {
            let ret = blk(&self.ssl);
            if ret > 0 {
                return Ok(ret);
            }

            let e = self.ssl.get_error(ret);
            match e {
                LibSslError::ErrorWantRead => {
                    try_ssl_stream!(self.flush());
                    let len = try_ssl_stream!(self.stream.read(&mut self.buf[..]));
                    if len == 0 {
                        self.ssl.get_rbio().set_eof(true);
                    } else {
                        try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len]));
                    }
                }
                LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) }
                LibSslError::ErrorZeroReturn => return Err(SslSessionClosed),
                LibSslError::ErrorSsl => return Err(SslError::get()),
                LibSslError::ErrorSyscall if ret == 0 => return Ok(0),
                err => panic!("unexpected error {:?} with ret {}", err, ret),
            }
        }
    }

    fn write_through(&mut self) -> io::Result<()> {
        io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
        self.kind.mut_stream()
    }

    /// Get the compression currently in use.  The result will be
    /// either None, indicating no compression is in use, or a string
    /// with the compression name.
    pub fn get_compression(&self) -> Option<String> {
        let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) };
        let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) };
        if ptr == ptr::null() {
            return None;
        }
@@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> {
    /// This method needs the `npn` feature.
    #[cfg(feature = "npn")]
    pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> {
        self.ssl.get_selected_npn_protocol()
        self.kind.ssl().get_selected_npn_protocol()
    }

    /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any).
    pub fn pending(&self) -> usize {
        self.ssl.pending()
        self.kind.ssl().pending()
    }
}

impl<S: Read+Write> Read for SslStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
            Ok(len) => Ok(len as usize),
            Err(SslSessionClosed) => Ok(0),
            Err(StreamError(e)) => Err(e),
            Err(e @ OpenSslErrors(_)) => {
                Err(io::Error::new(io::ErrorKind::Other, e))
            }
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.read(buf)
        }
    }
}

impl<S: Read+Write> Write for SslStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
            Ok(len) => len as usize,
            Err(SslSessionClosed) => 0,
            Err(StreamError(e)) => return Err(e),
            Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
        };
        try!(self.write_through());
        Ok(count)
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.write(buf)
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        try!(self.write_through());
        self.stream.flush()
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.flush()
        }
    }
}