Loading openssl-sys/src/lib.rs +1 −0 Original line number Diff line number Diff line Loading @@ -741,6 +741,7 @@ extern "C" { pub fn SSL_get_wbio(ssl: *mut SSL) -> *mut BIO; pub fn SSL_accept(ssl: *mut SSL) -> c_int; pub fn SSL_connect(ssl: *mut SSL) -> c_int; pub fn SSL_do_handshake(ssl: *mut SSL) -> c_int; pub fn SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long; pub fn SSL_get_error(ssl: *mut SSL, ret: c_int) -> c_int; Loading openssl/src/ssl/mod.rs +112 −8 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; use std::error as stderror; use std::mem; use std::str; use std::path::Path; Loading Loading @@ -832,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } fn handshake(&self) -> c_int { unsafe { ffi::SSL_do_handshake(self.ssl) } } fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } Loading Loading @@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { let ssl = try!(ssl.into_ssl()); pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, HandshakeError<S>>{ let ssl = try!(ssl.into_ssl().map_err(|e| { HandshakeError::Failure(Error::Ssl(e)) })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), err => Err(err) e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { let ssl = try!(ssl.into_ssl()); pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, HandshakeError<S>> { let ssl = try!(ssl.into_ssl().map_err(|e| { HandshakeError::Failure(Error::Ssl(e)) })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), err => Err(err) e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } err => Err(HandshakeError::Failure(err)), } } } Loading Loading @@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> { } } /// An error or intermediate state after a TLS handshake attempt. #[derive(Debug)] pub enum HandshakeError<S> { /// The handshake failed. Failure(Error), /// The handshake was interrupted midway through. Interrupted(MidHandshakeSslStream<S>), } impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { fn description(&self) -> &str { match *self { HandshakeError::Failure(ref e) => e.description(), HandshakeError::Interrupted(ref e) => e.error.description(), } } fn cause(&self) -> Option<&stderror::Error> { match *self { HandshakeError::Failure(ref e) => Some(e), HandshakeError::Interrupted(ref e) => Some(&e.error), } } } impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { try!(f.write_str(stderror::Error::description(self))); if let Some(e) = stderror::Error::cause(self) { try!(write!(f, ": {}", e)); } Ok(()) } } /// An SSL stream midway through the handshake process. #[derive(Debug)] pub struct MidHandshakeSslStream<S> { stream: SslStream<S>, error: Error, } impl<S> MidHandshakeSslStream<S> { /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &S { self.stream.get_ref() } /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut S { self.stream.get_mut() } /// Returns a shared reference to the `SslContext` of the stream. pub fn ssl(&self) -> &Ssl { self.stream.ssl() } /// Returns the underlying error which interrupted this handshake. pub fn error(&self) -> &Error { &self.error } /// Restarts the handshake process. pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { let ret = self.stream.ssl.handshake(); if ret > 0 { Ok(self.stream) } else { match self.stream.make_error(ret) { e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { self.error = e; Err(HandshakeError::Interrupted(self)) } err => Err(HandshakeError::Failure(err)), } } } } impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); Loading openssl/src/ssl/tests/mod.rs +24 −11 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ use crypto::hash::Type::SHA256; use ssl; use ssl::SSL_VERIFY_PEER; use ssl::SslMethod::Sslv23; use ssl::SslMethod; use ssl::{SslMethod, HandshakeError}; use ssl::error::Error; use ssl::{SslContext, SslStream}; use x509::X509StoreContext; Loading Loading @@ -133,6 +133,7 @@ impl Drop for Server { } #[cfg(feature = "dtlsv1")] #[derive(Debug)] struct UdpConnected(UdpSocket); #[cfg(feature = "dtlsv1")] Loading Loading @@ -846,10 +847,10 @@ fn test_sslv2_connect_failure() { .unwrap(); } fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool { unsafe { let mut set: select::fd_set = mem::zeroed(); select::fd_set(&mut set, stream.get_ref()); select::fd_set(&mut set, stream); let write = if read { 0 as *mut _ Loading @@ -861,7 +862,19 @@ fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { } else { &mut set as *mut _ }; select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms).unwrap() select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap() } } fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>) -> SslStream<TcpStream> { match res { Ok(s) => s, Err(HandshakeError::Interrupted(s)) => { wait_io(s.get_ref(), true, 1_000); handshake(s.handshake()) } Err(err) => panic!("error on handshake {:?}", err), } } Loading @@ -870,7 +883,7 @@ fn test_write_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); let mut stream = SslStream::connect(&cx, stream).unwrap(); let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { Loading @@ -886,10 +899,10 @@ fn test_write_nonblocking() { break; } Err(Error::WantRead(_)) => { assert!(wait_io(&stream, true, 1000)); assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(_)) => { assert!(wait_io(&stream, false, 1000)); assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); Loading @@ -907,7 +920,7 @@ fn test_read_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); let mut stream = SslStream::connect(&cx, stream).unwrap(); let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { Loading @@ -924,10 +937,10 @@ fn test_read_nonblocking() { break; } Err(Error::WantRead(..)) => { assert!(wait_io(&stream, true, 1000)); assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(..)) => { assert!(wait_io(&stream, false, 1000)); assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); Loading @@ -944,7 +957,7 @@ fn test_read_nonblocking() { n } Err(Error::WantRead(..)) => { assert!(wait_io(&stream, true, 3000)); assert!(wait_io(stream.get_ref(), true, 3000)); // Second read should return application data. stream.read(&mut input_buffer).unwrap() } Loading Loading
openssl-sys/src/lib.rs +1 −0 Original line number Diff line number Diff line Loading @@ -741,6 +741,7 @@ extern "C" { pub fn SSL_get_wbio(ssl: *mut SSL) -> *mut BIO; pub fn SSL_accept(ssl: *mut SSL) -> c_int; pub fn SSL_connect(ssl: *mut SSL) -> c_int; pub fn SSL_do_handshake(ssl: *mut SSL) -> c_int; pub fn SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long; pub fn SSL_get_error(ssl: *mut SSL, ret: c_int) -> c_int; Loading
openssl/src/ssl/mod.rs +112 −8 Original line number Diff line number Diff line Loading @@ -5,6 +5,7 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; use std::error as stderror; use std::mem; use std::str; use std::path::Path; Loading Loading @@ -832,6 +833,10 @@ impl Ssl { unsafe { ffi::SSL_accept(self.ssl) } } fn handshake(&self) -> c_int { unsafe { ffi::SSL_do_handshake(self.ssl) } } fn read(&self, buf: &mut [u8]) -> c_int { let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) } Loading Loading @@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> { } /// Creates an SSL/TLS client operating over the provided stream. pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { let ssl = try!(ssl.into_ssl()); pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, HandshakeError<S>>{ let ssl = try!(ssl.into_ssl().map_err(|e| { HandshakeError::Failure(Error::Ssl(e)) })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.connect(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), err => Err(err) e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } err => Err(HandshakeError::Failure(err)), } } } /// Creates an SSL/TLS server operating over the provided stream. pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> { let ssl = try!(ssl.into_ssl()); pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, HandshakeError<S>> { let ssl = try!(ssl.into_ssl().map_err(|e| { HandshakeError::Failure(Error::Ssl(e)) })); let mut stream = Self::new_base(ssl, stream); let ret = stream.ssl.accept(); if ret > 0 { Ok(stream) } else { match stream.make_error(ret) { Error::WantRead(..) | Error::WantWrite(..) => Ok(stream), err => Err(err) e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } err => Err(HandshakeError::Failure(err)), } } } Loading Loading @@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> { } } /// An error or intermediate state after a TLS handshake attempt. #[derive(Debug)] pub enum HandshakeError<S> { /// The handshake failed. Failure(Error), /// The handshake was interrupted midway through. Interrupted(MidHandshakeSslStream<S>), } impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> { fn description(&self) -> &str { match *self { HandshakeError::Failure(ref e) => e.description(), HandshakeError::Interrupted(ref e) => e.error.description(), } } fn cause(&self) -> Option<&stderror::Error> { match *self { HandshakeError::Failure(ref e) => Some(e), HandshakeError::Interrupted(ref e) => Some(&e.error), } } } impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { try!(f.write_str(stderror::Error::description(self))); if let Some(e) = stderror::Error::cause(self) { try!(write!(f, ": {}", e)); } Ok(()) } } /// An SSL stream midway through the handshake process. #[derive(Debug)] pub struct MidHandshakeSslStream<S> { stream: SslStream<S>, error: Error, } impl<S> MidHandshakeSslStream<S> { /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &S { self.stream.get_ref() } /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut S { self.stream.get_mut() } /// Returns a shared reference to the `SslContext` of the stream. pub fn ssl(&self) -> &Ssl { self.stream.ssl() } /// Returns the underlying error which interrupted this handshake. pub fn error(&self) -> &Error { &self.error } /// Restarts the handshake process. pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> { let ret = self.stream.ssl.handshake(); if ret > 0 { Ok(self.stream) } else { match self.stream.make_error(ret) { e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { self.error = e; Err(HandshakeError::Interrupted(self)) } err => Err(HandshakeError::Failure(err)), } } } } impl<S> SslStream<S> { fn make_error(&mut self, ret: c_int) -> Error { self.check_panic(); Loading
openssl/src/ssl/tests/mod.rs +24 −11 Original line number Diff line number Diff line Loading @@ -17,7 +17,7 @@ use crypto::hash::Type::SHA256; use ssl; use ssl::SSL_VERIFY_PEER; use ssl::SslMethod::Sslv23; use ssl::SslMethod; use ssl::{SslMethod, HandshakeError}; use ssl::error::Error; use ssl::{SslContext, SslStream}; use x509::X509StoreContext; Loading Loading @@ -133,6 +133,7 @@ impl Drop for Server { } #[cfg(feature = "dtlsv1")] #[derive(Debug)] struct UdpConnected(UdpSocket); #[cfg(feature = "dtlsv1")] Loading Loading @@ -846,10 +847,10 @@ fn test_sslv2_connect_failure() { .unwrap(); } fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool { unsafe { let mut set: select::fd_set = mem::zeroed(); select::fd_set(&mut set, stream.get_ref()); select::fd_set(&mut set, stream); let write = if read { 0 as *mut _ Loading @@ -861,7 +862,19 @@ fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool { } else { &mut set as *mut _ }; select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms).unwrap() select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap() } } fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>) -> SslStream<TcpStream> { match res { Ok(s) => s, Err(HandshakeError::Interrupted(s)) => { wait_io(s.get_ref(), true, 1_000); handshake(s.handshake()) } Err(err) => panic!("error on handshake {:?}", err), } } Loading @@ -870,7 +883,7 @@ fn test_write_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); let mut stream = SslStream::connect(&cx, stream).unwrap(); let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { Loading @@ -886,10 +899,10 @@ fn test_write_nonblocking() { break; } Err(Error::WantRead(_)) => { assert!(wait_io(&stream, true, 1000)); assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(_)) => { assert!(wait_io(&stream, false, 1000)); assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); Loading @@ -907,7 +920,7 @@ fn test_read_nonblocking() { let (_s, stream) = Server::new(); stream.set_nonblocking(true).unwrap(); let cx = SslContext::new(Sslv23).unwrap(); let mut stream = SslStream::connect(&cx, stream).unwrap(); let mut stream = handshake(SslStream::connect(&cx, stream)); let mut iterations = 0; loop { Loading @@ -924,10 +937,10 @@ fn test_read_nonblocking() { break; } Err(Error::WantRead(..)) => { assert!(wait_io(&stream, true, 1000)); assert!(wait_io(stream.get_ref(), true, 1000)); } Err(Error::WantWrite(..)) => { assert!(wait_io(&stream, false, 1000)); assert!(wait_io(stream.get_ref(), false, 1000)); } Err(other) => { panic!("Unexpected SSL Error: {:?}", other); Loading @@ -944,7 +957,7 @@ fn test_read_nonblocking() { n } Err(Error::WantRead(..)) => { assert!(wait_io(&stream, true, 3000)); assert!(wait_io(stream.get_ref(), true, 3000)); // Second read should return application data. stream.read(&mut input_buffer).unwrap() } Loading