Loading openssl/src/ssl/mod.rs +165 −78 Original line number Diff line number Diff line Loading @@ -5,7 +5,6 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; use std::iter; use std::mem; use std::net; use std::path::Path; Loading Loading @@ -740,52 +739,181 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } struct IndirectStream<S> { stream: S, ssl: Arc<Ssl>, // Max TLS record size is 16k buf: Box<[u8; 16 * 1024]>, } impl<S: Clone> Clone for IndirectStream<S> { fn clone(&self) -> IndirectStream<S> { IndirectStream { stream: self.stream.clone(), ssl: self.ssl.clone(), buf: Box::new(*self.buf) } } } impl IndirectStream<net::TcpStream> { fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> { Ok(IndirectStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), buf: Box::new(*self.buf) }) } } impl<S: Read+Write> IndirectStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), buf: Box::new([0; 16 * 1024]), }) } fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); Ok(ssl) } fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); Ok(ssl) } fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int { loop { let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } let e = self.ssl.get_error(ret); match e { LibSslError::ErrorWantRead => { try_ssl_stream!(self.flush()); let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); if len == 0 { self.ssl.get_rbio().set_eof(true); } else { try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); } } LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), LibSslError::ErrorSsl => return Err(SslError::get()), LibSslError::ErrorSyscall if ret == 0 => return Ok(0), err => panic!("unexpected error {:?} with ret {}", err, ret), } } } fn write_through(&mut self) -> io::Result<()> { io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) } } impl<S: Read+Write> Read for IndirectStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Ok(len as usize), Err(SslSessionClosed) => Ok(0), Err(StreamError(e)) => Err(e), Err(e @ OpenSslErrors(_)) => { Err(io::Error::new(io::ErrorKind::Other, e)) } } } } impl<S: Read+Write> Write for IndirectStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { Ok(len) => len as usize, Err(SslSessionClosed) => 0, Err(StreamError(e)) => return Err(e), Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), }; try!(self.write_through()); Ok(count) } fn flush(&mut self) -> io::Result<()> { try!(self.write_through()); self.stream.flush() } } #[derive(Clone)] enum StreamKind<S> { Indirect(IndirectStream<S>), } impl<S> StreamKind<S> { fn stream(&self) -> &S { match *self { StreamKind::Indirect(ref s) => &s.stream } } fn mut_stream(&mut self) -> &mut S { match *self { StreamKind::Indirect(ref mut s) => &mut s.stream } } fn ssl(&self) -> &Ssl { match *self { StreamKind::Indirect(ref s) => &s.ssl } } } /// A stream wrapper which handles SSL encryption for an underlying stream. #[derive(Clone)] pub struct SslStream<S> { stream: S, ssl: Arc<Ssl>, buf: Vec<u8> kind: StreamKind<S>, } impl SslStream<net::TcpStream> { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { let kind = match self.kind { StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) }; Ok(SslStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), buf: self.buf.clone(), kind: kind }) } } impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl) write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl()) } } impl<S: Read+Write> SslStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let stream = try!(IndirectStream::new_client(ssl, stream)); Ok(SslStream { stream: stream, ssl: Arc::new(ssl), // Maximum TLS record size is 16k buf: iter::repeat(0).take(16 * 1024).collect(), kind: StreamKind::Indirect(stream) }) } pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let mut ssl = try!(SslStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); Ok(ssl) } pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let mut ssl = try!(SslStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); Ok(ssl) let stream = try!(IndirectStream::new_server(ssl, stream)); Ok(SslStream { kind: StreamKind::Indirect(stream) }) } /// # Deprecated Loading @@ -811,12 +939,12 @@ impl<S: Read+Write> SslStream<S> { /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.stream self.kind.stream() } /// Return the certificate of the peer pub fn get_peer_certificate(&self) -> Option<X509> { self.ssl.get_peer_certificate() self.kind.ssl().get_peer_certificate() } /// Returns a mutable reference to the underlying stream. Loading @@ -826,46 +954,14 @@ impl<S: Read+Write> SslStream<S> { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { &mut self.stream } fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int { loop { let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } let e = self.ssl.get_error(ret); match e { LibSslError::ErrorWantRead => { try_ssl_stream!(self.flush()); let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); if len == 0 { self.ssl.get_rbio().set_eof(true); } else { try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); } } LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), LibSslError::ErrorSsl => return Err(SslError::get()), LibSslError::ErrorSyscall if ret == 0 => return Ok(0), err => panic!("unexpected error {:?} with ret {}", err, ret), } } } fn write_through(&mut self) -> io::Result<()> { io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) self.kind.mut_stream() } /// Get the compression currently in use. The result will be /// either None, indicating no compression is in use, or a string /// with the compression name. pub fn get_compression(&self) -> Option<String> { let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) }; if ptr == ptr::null() { return None; } Loading @@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> { /// This method needs the `npn` feature. #[cfg(feature = "npn")] pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { self.ssl.get_selected_npn_protocol() self.kind.ssl().get_selected_npn_protocol() } /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { self.ssl.pending() self.kind.ssl().pending() } } impl<S: Read+Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Ok(len as usize), Err(SslSessionClosed) => Ok(0), Err(StreamError(e)) => Err(e), Err(e @ OpenSslErrors(_)) => { Err(io::Error::new(io::ErrorKind::Other, e)) } match self.kind { StreamKind::Indirect(ref mut s) => s.read(buf) } } } impl<S: Read+Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { Ok(len) => len as usize, Err(SslSessionClosed) => 0, Err(StreamError(e)) => return Err(e), Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), }; try!(self.write_through()); Ok(count) match self.kind { StreamKind::Indirect(ref mut s) => s.write(buf) } } fn flush(&mut self) -> io::Result<()> { try!(self.write_through()); self.stream.flush() match self.kind { StreamKind::Indirect(ref mut s) => s.flush() } } } Loading Loading
openssl/src/ssl/mod.rs +165 −78 Original line number Diff line number Diff line Loading @@ -5,7 +5,6 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; use std::iter; use std::mem; use std::net; use std::path::Path; Loading Loading @@ -740,52 +739,181 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } struct IndirectStream<S> { stream: S, ssl: Arc<Ssl>, // Max TLS record size is 16k buf: Box<[u8; 16 * 1024]>, } impl<S: Clone> Clone for IndirectStream<S> { fn clone(&self) -> IndirectStream<S> { IndirectStream { stream: self.stream.clone(), ssl: self.ssl.clone(), buf: Box::new(*self.buf) } } } impl IndirectStream<net::TcpStream> { fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> { Ok(IndirectStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), buf: Box::new(*self.buf) }) } } impl<S: Read+Write> IndirectStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), buf: Box::new([0; 16 * 1024]), }) } fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); Ok(ssl) } fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); Ok(ssl) } fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int { loop { let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } let e = self.ssl.get_error(ret); match e { LibSslError::ErrorWantRead => { try_ssl_stream!(self.flush()); let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); if len == 0 { self.ssl.get_rbio().set_eof(true); } else { try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); } } LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), LibSslError::ErrorSsl => return Err(SslError::get()), LibSslError::ErrorSyscall if ret == 0 => return Ok(0), err => panic!("unexpected error {:?} with ret {}", err, ret), } } } fn write_through(&mut self) -> io::Result<()> { io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) } } impl<S: Read+Write> Read for IndirectStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Ok(len as usize), Err(SslSessionClosed) => Ok(0), Err(StreamError(e)) => Err(e), Err(e @ OpenSslErrors(_)) => { Err(io::Error::new(io::ErrorKind::Other, e)) } } } } impl<S: Read+Write> Write for IndirectStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { Ok(len) => len as usize, Err(SslSessionClosed) => 0, Err(StreamError(e)) => return Err(e), Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), }; try!(self.write_through()); Ok(count) } fn flush(&mut self) -> io::Result<()> { try!(self.write_through()); self.stream.flush() } } #[derive(Clone)] enum StreamKind<S> { Indirect(IndirectStream<S>), } impl<S> StreamKind<S> { fn stream(&self) -> &S { match *self { StreamKind::Indirect(ref s) => &s.stream } } fn mut_stream(&mut self) -> &mut S { match *self { StreamKind::Indirect(ref mut s) => &mut s.stream } } fn ssl(&self) -> &Ssl { match *self { StreamKind::Indirect(ref s) => &s.ssl } } } /// A stream wrapper which handles SSL encryption for an underlying stream. #[derive(Clone)] pub struct SslStream<S> { stream: S, ssl: Arc<Ssl>, buf: Vec<u8> kind: StreamKind<S>, } impl SslStream<net::TcpStream> { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { let kind = match self.kind { StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) }; Ok(SslStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), buf: self.buf.clone(), kind: kind }) } } impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl) write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl()) } } impl<S: Read+Write> SslStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let stream = try!(IndirectStream::new_client(ssl, stream)); Ok(SslStream { stream: stream, ssl: Arc::new(ssl), // Maximum TLS record size is 16k buf: iter::repeat(0).take(16 * 1024).collect(), kind: StreamKind::Indirect(stream) }) } pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let mut ssl = try!(SslStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); Ok(ssl) } pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let mut ssl = try!(SslStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); Ok(ssl) let stream = try!(IndirectStream::new_server(ssl, stream)); Ok(SslStream { kind: StreamKind::Indirect(stream) }) } /// # Deprecated Loading @@ -811,12 +939,12 @@ impl<S: Read+Write> SslStream<S> { /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.stream self.kind.stream() } /// Return the certificate of the peer pub fn get_peer_certificate(&self) -> Option<X509> { self.ssl.get_peer_certificate() self.kind.ssl().get_peer_certificate() } /// Returns a mutable reference to the underlying stream. Loading @@ -826,46 +954,14 @@ impl<S: Read+Write> SslStream<S> { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { &mut self.stream } fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int { loop { let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } let e = self.ssl.get_error(ret); match e { LibSslError::ErrorWantRead => { try_ssl_stream!(self.flush()); let len = try_ssl_stream!(self.stream.read(&mut self.buf[..])); if len == 0 { self.ssl.get_rbio().set_eof(true); } else { try_ssl_stream!(self.ssl.get_rbio().write_all(&self.buf[..len])); } } LibSslError::ErrorWantWrite => { try_ssl_stream!(self.flush()) } LibSslError::ErrorZeroReturn => return Err(SslSessionClosed), LibSslError::ErrorSsl => return Err(SslError::get()), LibSslError::ErrorSyscall if ret == 0 => return Ok(0), err => panic!("unexpected error {:?} with ret {}", err, ret), } } } fn write_through(&mut self) -> io::Result<()> { io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) self.kind.mut_stream() } /// Get the compression currently in use. The result will be /// either None, indicating no compression is in use, or a string /// with the compression name. pub fn get_compression(&self) -> Option<String> { let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) }; if ptr == ptr::null() { return None; } Loading @@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> { /// This method needs the `npn` feature. #[cfg(feature = "npn")] pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { self.ssl.get_selected_npn_protocol() self.kind.ssl().get_selected_npn_protocol() } /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { self.ssl.pending() self.kind.ssl().pending() } } impl<S: Read+Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Ok(len as usize), Err(SslSessionClosed) => Ok(0), Err(StreamError(e)) => Err(e), Err(e @ OpenSslErrors(_)) => { Err(io::Error::new(io::ErrorKind::Other, e)) } match self.kind { StreamKind::Indirect(ref mut s) => s.read(buf) } } } impl<S: Read+Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { Ok(len) => len as usize, Err(SslSessionClosed) => 0, Err(StreamError(e)) => return Err(e), Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), }; try!(self.write_through()); Ok(count) match self.kind { StreamKind::Indirect(ref mut s) => s.write(buf) } } fn flush(&mut self) -> io::Result<()> { try!(self.write_through()); self.stream.flush() match self.kind { StreamKind::Indirect(ref mut s) => s.flush() } } } Loading