diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index cb4448b82b220d58ee74a7e53c263deeb2487d08..18acf7f88990e0e28060b3ce27a2bf11d006629f 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,7 +5,6 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; -use std::iter; use std::mem; use std::net; use std::path::Path; @@ -740,52 +739,181 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } +struct IndirectStream { + stream: S, + ssl: Arc, + // Max TLS record size is 16k + buf: Box<[u8; 16 * 1024]>, +} + +impl Clone for IndirectStream { + fn clone(&self) -> IndirectStream { + IndirectStream { + stream: self.stream.clone(), + ssl: self.ssl.clone(), + buf: Box::new(*self.buf) + } + } +} + +impl IndirectStream { + fn try_clone(&self) -> io::Result> { + Ok(IndirectStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + buf: Box::new(*self.buf) + }) + } +} + +impl IndirectStream { + fn new_base(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + Ok(IndirectStream { + stream: stream, + ssl: Arc::new(ssl), + buf: Box::new([0; 16 * 1024]), + }) + } + + fn new_client(ssl: T, stream: S) -> Result, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); + try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); + Ok(ssl) + } + + fn new_server(ssl: T, stream: S) -> Result, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); + try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); + Ok(ssl) + } + + fn in_retry_wrapper(&mut self, mut blk: F) -> Result + where F: FnMut(&Ssl) -> c_int { + loop { + let ret = blk(&self.ssl); + if ret > 0 { + return Ok(ret); + } + + let e = self.ssl.get_error(ret); + match e { + LibSslError::ErrorWantRead => { + try_ssl_stream!(self.flush()); + let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); + if len == 0 { + self.ssl.get_rbio().set_eof(true); + } else { + try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); + } + } + LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } + LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), + LibSslError::ErrorSsl => return Err(SslError::get()), + LibSslError::ErrorSyscall if ret == 0 => return Ok(0), + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } + } + + fn write_through(&mut self) -> io::Result<()> { + io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) + } +} + +impl Read for IndirectStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { + Ok(len) => Ok(len as usize), + Err(SslSessionClosed) => Ok(0), + Err(StreamError(e)) => Err(e), + Err(e @ OpenSslErrors(_)) => { + Err(io::Error::new(io::ErrorKind::Other, e)) + } + } + } +} + +impl Write for IndirectStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { + Ok(len) => len as usize, + Err(SslSessionClosed) => 0, + Err(StreamError(e)) => return Err(e), + Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), + }; + try!(self.write_through()); + Ok(count) + } + + fn flush(&mut self) -> io::Result<()> { + try!(self.write_through()); + self.stream.flush() + } +} + +#[derive(Clone)] +enum StreamKind { + Indirect(IndirectStream), +} + +impl StreamKind { + fn stream(&self) -> &S { + match *self { + StreamKind::Indirect(ref s) => &s.stream + } + } + + fn mut_stream(&mut self) -> &mut S { + match *self { + StreamKind::Indirect(ref mut s) => &mut s.stream + } + } + + fn ssl(&self) -> &Ssl { + match *self { + StreamKind::Indirect(ref s) => &s.ssl + } + } +} + /// A stream wrapper which handles SSL encryption for an underlying stream. #[derive(Clone)] pub struct SslStream { - stream: S, - ssl: Arc, - buf: Vec + kind: StreamKind, } impl SslStream { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result> { + let kind = match self.kind { + StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) + }; Ok(SslStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - buf: self.buf.clone(), + kind: kind }) } } impl fmt::Debug for SslStream where S: fmt::Debug { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl) + write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl()) } } impl SslStream { - fn new_base(ssl: T, stream: S) -> Result, SslError> { - let ssl = try!(ssl.into_ssl()); + pub fn new_client(ssl: T, stream: S) -> Result, SslError> { + let stream = try!(IndirectStream::new_client(ssl, stream)); Ok(SslStream { - stream: stream, - ssl: Arc::new(ssl), - // Maximum TLS record size is 16k - buf: iter::repeat(0).take(16 * 1024).collect(), + kind: StreamKind::Indirect(stream) }) } - pub fn new_client(ssl: T, stream: S) -> Result, SslError> { - let mut ssl = try!(SslStream::new_base(ssl, stream)); - try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); - Ok(ssl) - } - pub fn new_server(ssl: T, stream: S) -> Result, SslError> { - let mut ssl = try!(SslStream::new_base(ssl, stream)); - try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); - Ok(ssl) + let stream = try!(IndirectStream::new_server(ssl, stream)); + Ok(SslStream { + kind: StreamKind::Indirect(stream) + }) } /// # Deprecated @@ -811,12 +939,12 @@ impl SslStream { /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.kind.stream() } /// Return the certificate of the peer pub fn get_peer_certificate(&self) -> Option { - self.ssl.get_peer_certificate() + self.kind.ssl().get_peer_certificate() } /// Returns a mutable reference to the underlying stream. @@ -826,46 +954,14 @@ impl SslStream { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { - &mut self.stream - } - - fn in_retry_wrapper(&mut self, mut blk: F) - -> Result where F: FnMut(&Ssl) -> c_int { - loop { - let ret = blk(&self.ssl); - if ret > 0 { - return Ok(ret); - } - - let e = self.ssl.get_error(ret); - match e { - LibSslError::ErrorWantRead => { - try_ssl_stream!(self.flush()); - let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); - if len == 0 { - self.ssl.get_rbio().set_eof(true); - } else { - try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); - } - } - LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } - LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), - LibSslError::ErrorSsl => return Err(SslError::get()), - LibSslError::ErrorSyscall if ret == 0 => return Ok(0), - err => panic!("unexpected error {:?} with ret {}", err, ret), - } - } - } - - fn write_through(&mut self) -> io::Result<()> { - io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) + self.kind.mut_stream() } /// Get the compression currently in use. The result will be /// either None, indicating no compression is in use, or a string /// with the compression name. pub fn get_compression(&self) -> Option { - let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; + let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) }; if ptr == ptr::null() { return None; } @@ -886,43 +982,34 @@ impl SslStream { /// This method needs the `npn` feature. #[cfg(feature = "npn")] pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { - self.ssl.get_selected_npn_protocol() + self.kind.ssl().get_selected_npn_protocol() } /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { - self.ssl.pending() + self.kind.ssl().pending() } } impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { - Ok(len) => Ok(len as usize), - Err(SslSessionClosed) => Ok(0), - Err(StreamError(e)) => Err(e), - Err(e @ OpenSslErrors(_)) => { - Err(io::Error::new(io::ErrorKind::Other, e)) - } + match self.kind { + StreamKind::Indirect(ref mut s) => s.read(buf) } } } impl Write for SslStream { fn write(&mut self, buf: &[u8]) -> io::Result { - let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { - Ok(len) => len as usize, - Err(SslSessionClosed) => 0, - Err(StreamError(e)) => return Err(e), - Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - }; - try!(self.write_through()); - Ok(count) + match self.kind { + StreamKind::Indirect(ref mut s) => s.write(buf) + } } fn flush(&mut self) -> io::Result<()> { - try!(self.write_through()); - self.stream.flush() + match self.kind { + StreamKind::Indirect(ref mut s) => s.flush() + } } }