Loading openssl/src/ssl/mod.rs +169 −12 Original line number Diff line number Diff line Loading @@ -603,11 +603,6 @@ impl Ssl { return Err(SslError::get()); } let ssl = Ssl { ssl: ssl }; let rbio = try!(MemBio::new()); let wbio = try!(MemBio::new()); unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(ssl) } Loading Loading @@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> { impl<S: Read+Write> IndirectStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let rbio = try!(MemBio::new()); let wbio = try!(MemBio::new()); unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), Loading Loading @@ -852,27 +853,139 @@ impl<S: Read+Write> Write for IndirectStream<S> { } } #[derive(Clone)] struct DirectStream<S> { stream: S, ssl: Arc<Ssl>, } impl DirectStream<net::TcpStream> { fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> { Ok(DirectStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), }) } } impl<S> DirectStream<S> { fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { unsafe { let bio = ffi::BIO_new_socket(sock, 0); if bio == ptr::null_mut() { return Err(SslError::get()); } ffi::SSL_set_bio(ssl.ssl, bio, bio); } Ok(DirectStream { stream: stream, ssl: Arc::new(ssl), }) } fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { let ssl = try!(DirectStream::new_base(ssl, stream, sock)); let ret = ssl.ssl.connect(); if ret > 0 { Ok(ssl) } else { Err(ssl.make_error(ret)) } } fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { let ssl = try!(DirectStream::new_base(ssl, stream, sock)); let ret = ssl.ssl.accept(); if ret > 0 { Ok(ssl) } else { Err(ssl.make_error(ret)) } } fn make_error(&self, ret: c_int) -> SslError { match self.ssl.get_error(ret) { LibSslError::ErrorSsl => SslError::get(), LibSslError::ErrorSyscall => { let err = SslError::get(); let count = match err { SslError::OpenSslErrors(ref v) => v.len(), _ => unreachable!(), }; if count == 0 { if ret == 0 { SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, "unexpected EOF observed")) } else { SslError::StreamError(io::Error::last_os_error()) } } else { err } } err => panic!("unexpected error {:?} with ret {}", err, ret), } } } impl<S> Read for DirectStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { let ret = self.ssl.read(buf); if ret >= 0 { return Ok(ret as usize); } match self.make_error(ret) { SslError::StreamError(e) => Err(e), e => Err(io::Error::new(io::ErrorKind::Other, e)), } } } impl<S: Write> Write for DirectStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let ret = self.ssl.write(buf); if ret > 0 { return Ok(ret as usize); } match self.make_error(ret) { SslError::StreamError(e) => Err(e), e => Err(io::Error::new(io::ErrorKind::Other, e)), } } fn flush(&mut self) -> io::Result<()> { self.stream.flush() } } #[derive(Clone)] enum StreamKind<S> { Indirect(IndirectStream<S>), Direct(DirectStream<S>), } impl<S> StreamKind<S> { fn stream(&self) -> &S { match *self { StreamKind::Indirect(ref s) => &s.stream StreamKind::Indirect(ref s) => &s.stream, StreamKind::Direct(ref s) => &s.stream, } } fn mut_stream(&mut self) -> &mut S { match *self { StreamKind::Indirect(ref mut s) => &mut s.stream StreamKind::Indirect(ref mut s) => &mut s.stream, StreamKind::Direct(ref mut s) => &mut s.stream, } } fn ssl(&self) -> &Ssl { match *self { StreamKind::Indirect(ref s) => &s.ssl StreamKind::Indirect(ref s) => &s.ssl, StreamKind::Direct(ref s) => &s.ssl, } } } Loading @@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { let kind = match self.kind { StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())), StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone())) }; Ok(SslStream { kind: kind Loading @@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { } } #[cfg(unix)] impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> { pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let fd = stream.as_raw_fd() as c_int; let stream = try!(DirectStream::new_client(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let fd = stream.as_raw_fd() as c_int; let stream = try!(DirectStream::new_server(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } } #[cfg(windows)] impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> { pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let fd = stream.as_raw_socket() as c_int; let stream = try!(DirectStream::new_client(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let fd = stream.as_raw_socket() as c_int; let stream = try!(DirectStream::new_server(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } } impl<S: Read+Write> SslStream<S> { pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let stream = try!(IndirectStream::new_client(ssl, stream)); Loading Loading @@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> { impl<S: Read+Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.kind { StreamKind::Indirect(ref mut s) => s.read(buf) StreamKind::Indirect(ref mut s) => s.read(buf), StreamKind::Direct(ref mut s) => s.read(buf), } } } Loading @@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> { impl<S: Read+Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { match self.kind { StreamKind::Indirect(ref mut s) => s.write(buf) StreamKind::Indirect(ref mut s) => s.write(buf), StreamKind::Direct(ref mut s) => s.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self.kind { StreamKind::Indirect(ref mut s) => s.flush() StreamKind::Indirect(ref mut s) => s.flush(), StreamKind::Direct(ref mut s) => s.flush(), } } } Loading openssl/src/ssl/tests.rs +18 −1 Original line number Diff line number Diff line Loading @@ -317,8 +317,17 @@ fn test_write() { stream.flush().unwrap(); } #[test] fn test_write_direct() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); stream.write_all("hello".as_bytes()).unwrap(); stream.flush().unwrap(); stream.write_all(" there".as_bytes()).unwrap(); stream.flush().unwrap(); } run_test!(get_peer_certificate, |method, stream| { //let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap(); let cert = stream.get_peer_certificate().unwrap(); let fingerprint = cert.fingerprint(SHA256).unwrap(); Loading Loading @@ -349,6 +358,14 @@ fn test_read() { io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); } #[test] fn test_read_direct() { let tcp = TcpStream::connect("127.0.0.1:15418").unwrap(); let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap(); stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap(); stream.flush().unwrap(); io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); } #[test] fn test_pending() { Loading Loading
openssl/src/ssl/mod.rs +169 −12 Original line number Diff line number Diff line Loading @@ -603,11 +603,6 @@ impl Ssl { return Err(SslError::get()); } let ssl = Ssl { ssl: ssl }; let rbio = try!(MemBio::new()); let wbio = try!(MemBio::new()); unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(ssl) } Loading Loading @@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> { impl<S: Read+Write> IndirectStream<S> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let rbio = try!(MemBio::new()); let wbio = try!(MemBio::new()); unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), Loading Loading @@ -852,27 +853,139 @@ impl<S: Read+Write> Write for IndirectStream<S> { } } #[derive(Clone)] struct DirectStream<S> { stream: S, ssl: Arc<Ssl>, } impl DirectStream<net::TcpStream> { fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> { Ok(DirectStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), }) } } impl<S> DirectStream<S> { fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { unsafe { let bio = ffi::BIO_new_socket(sock, 0); if bio == ptr::null_mut() { return Err(SslError::get()); } ffi::SSL_set_bio(ssl.ssl, bio, bio); } Ok(DirectStream { stream: stream, ssl: Arc::new(ssl), }) } fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { let ssl = try!(DirectStream::new_base(ssl, stream, sock)); let ret = ssl.ssl.connect(); if ret > 0 { Ok(ssl) } else { Err(ssl.make_error(ret)) } } fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> { let ssl = try!(DirectStream::new_base(ssl, stream, sock)); let ret = ssl.ssl.accept(); if ret > 0 { Ok(ssl) } else { Err(ssl.make_error(ret)) } } fn make_error(&self, ret: c_int) -> SslError { match self.ssl.get_error(ret) { LibSslError::ErrorSsl => SslError::get(), LibSslError::ErrorSyscall => { let err = SslError::get(); let count = match err { SslError::OpenSslErrors(ref v) => v.len(), _ => unreachable!(), }; if count == 0 { if ret == 0 { SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, "unexpected EOF observed")) } else { SslError::StreamError(io::Error::last_os_error()) } } else { err } } err => panic!("unexpected error {:?} with ret {}", err, ret), } } } impl<S> Read for DirectStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { let ret = self.ssl.read(buf); if ret >= 0 { return Ok(ret as usize); } match self.make_error(ret) { SslError::StreamError(e) => Err(e), e => Err(io::Error::new(io::ErrorKind::Other, e)), } } } impl<S: Write> Write for DirectStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { let ret = self.ssl.write(buf); if ret > 0 { return Ok(ret as usize); } match self.make_error(ret) { SslError::StreamError(e) => Err(e), e => Err(io::Error::new(io::ErrorKind::Other, e)), } } fn flush(&mut self) -> io::Result<()> { self.stream.flush() } } #[derive(Clone)] enum StreamKind<S> { Indirect(IndirectStream<S>), Direct(DirectStream<S>), } impl<S> StreamKind<S> { fn stream(&self) -> &S { match *self { StreamKind::Indirect(ref s) => &s.stream StreamKind::Indirect(ref s) => &s.stream, StreamKind::Direct(ref s) => &s.stream, } } fn mut_stream(&mut self) -> &mut S { match *self { StreamKind::Indirect(ref mut s) => &mut s.stream StreamKind::Indirect(ref mut s) => &mut s.stream, StreamKind::Direct(ref mut s) => &mut s.stream, } } fn ssl(&self) -> &Ssl { match *self { StreamKind::Indirect(ref s) => &s.ssl StreamKind::Indirect(ref s) => &s.ssl, StreamKind::Direct(ref s) => &s.ssl, } } } Loading @@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { let kind = match self.kind { StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())), StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone())) }; Ok(SslStream { kind: kind Loading @@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { } } #[cfg(unix)] impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> { pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let fd = stream.as_raw_fd() as c_int; let stream = try!(DirectStream::new_client(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let ssl = try!(ssl.into_ssl()); let fd = stream.as_raw_fd() as c_int; let stream = try!(DirectStream::new_server(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } } #[cfg(windows)] impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> { pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let fd = stream.as_raw_socket() as c_int; let stream = try!(DirectStream::new_client(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let fd = stream.as_raw_socket() as c_int; let stream = try!(DirectStream::new_server(ssl, stream, fd)); Ok(SslStream { kind: StreamKind::Direct(stream) }) } } impl<S: Read+Write> SslStream<S> { pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { let stream = try!(IndirectStream::new_client(ssl, stream)); Loading Loading @@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> { impl<S: Read+Write> Read for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { match self.kind { StreamKind::Indirect(ref mut s) => s.read(buf) StreamKind::Indirect(ref mut s) => s.read(buf), StreamKind::Direct(ref mut s) => s.read(buf), } } } Loading @@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> { impl<S: Read+Write> Write for SslStream<S> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { match self.kind { StreamKind::Indirect(ref mut s) => s.write(buf) StreamKind::Indirect(ref mut s) => s.write(buf), StreamKind::Direct(ref mut s) => s.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self.kind { StreamKind::Indirect(ref mut s) => s.flush() StreamKind::Indirect(ref mut s) => s.flush(), StreamKind::Direct(ref mut s) => s.flush(), } } } Loading
openssl/src/ssl/tests.rs +18 −1 Original line number Diff line number Diff line Loading @@ -317,8 +317,17 @@ fn test_write() { stream.flush().unwrap(); } #[test] fn test_write_direct() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); stream.write_all("hello".as_bytes()).unwrap(); stream.flush().unwrap(); stream.write_all(" there".as_bytes()).unwrap(); stream.flush().unwrap(); } run_test!(get_peer_certificate, |method, stream| { //let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap(); let cert = stream.get_peer_certificate().unwrap(); let fingerprint = cert.fingerprint(SHA256).unwrap(); Loading Loading @@ -349,6 +358,14 @@ fn test_read() { io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); } #[test] fn test_read_direct() { let tcp = TcpStream::connect("127.0.0.1:15418").unwrap(); let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap(); stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap(); stream.flush().unwrap(); io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); } #[test] fn test_pending() { Loading