From f99c10155993ba2d34aa36f74d3458872787316b Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sun, 4 Mar 2018 22:05:09 -0800 Subject: [PATCH] Add test for stateless connection --- openssl/src/ssl/test.rs | 121 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 2 deletions(-) diff --git a/openssl/src/ssl/test.rs b/openssl/src/ssl/test.rs index c732f3fc9..ddcb49ff5 100644 --- a/openssl/src/ssl/test.rs +++ b/openssl/src/ssl/test.rs @@ -19,8 +19,9 @@ use hash::MessageDigest; use ocsp::{OcspResponse, OcspResponseStatus}; use ssl; use ssl::{Error, HandshakeError, ShutdownResult, Ssl, SslAcceptor, SslConnector, SslContext, - SslFiletype, SslMethod, SslSessionCacheMode, SslStream, SslVerifyMode, StatusType}; -#[cfg(any(ossl110))] + SslFiletype, SslMethod, SslSessionCacheMode, SslStream, MidHandshakeSslStream, + SslVerifyMode, StatusType}; +#[cfg(any(ossl110, ossl111))] use ssl::SslVersion; use x509::{X509, X509Name, X509StoreContext, X509VerifyResult}; #[cfg(any(ossl102, ossl110))] @@ -1389,3 +1390,119 @@ fn _check_kinds() { is_send::>(); is_sync::>(); } + +#[derive(Debug)] +struct MemoryStream { + incoming: io::Cursor>, + outgoing: Vec, +} + +impl MemoryStream { + pub fn new() -> Self { Self { + incoming: io::Cursor::new(Vec::new()), + outgoing: Vec::new(), + }} + + pub fn extend_incoming(&mut self, data: &[u8]) { + self.incoming.get_mut().extend_from_slice(data); + } + + pub fn take_outgoing(&mut self) -> Outgoing { Outgoing(&mut self.outgoing) } +} + +impl Read for MemoryStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let n = self.incoming.read(buf)?; + if self.incoming.position() == self.incoming.get_ref().len() as u64 { + self.incoming.set_position(0); + self.incoming.get_mut().clear(); + } + if n == 0 { + return Err(io::Error::new(io::ErrorKind::WouldBlock, "no data available")); + } + Ok(n) + } +} + +impl Write for MemoryStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.outgoing.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { Ok(()) } +} + +pub struct Outgoing<'a>(&'a mut Vec); + +impl<'a> Drop for Outgoing<'a> { + fn drop(&mut self) { + self.0.clear(); + } +} + +impl<'a> ::std::ops::Deref for Outgoing<'a> { + type Target = [u8]; + fn deref(&self) -> &[u8] { &self.0 } +} + +impl<'a> AsRef<[u8]> for Outgoing<'a> { + fn as_ref(&self) -> &[u8] { &self.0 } +} + +fn send(from: &mut MemoryStream, to: &mut MemoryStream) { + to.extend_incoming(&from.take_outgoing()); +} + +#[test] +#[cfg(ossl111)] +fn stateless() { + use super::SslOptions; + + fn hs(stream: Result, HandshakeError>) -> Result, MidHandshakeSslStream> { + match stream { + Ok(stream) => Ok(stream), + Err(HandshakeError::WouldBlock(stream)) => Err(stream), + Err(e) => panic!("unexpected error: {:?}", e), + } + } + + // + // Setup + // + + let mut client_ctx = SslContext::builder(SslMethod::tls()).unwrap(); + client_ctx.clear_options(SslOptions::ENABLE_MIDDLEBOX_COMPAT); + let client_stream = Ssl::new(&client_ctx.build()).unwrap(); + + let mut server_ctx = SslContext::builder(SslMethod::tls()).unwrap(); + server_ctx.set_certificate_file(&Path::new("test/cert.pem"), SslFiletype::PEM) + .unwrap(); + server_ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM) + .unwrap(); + const COOKIE: &[u8] = b"chocolate chip"; + server_ctx.set_stateless_cookie_generate_cb(|_tls, buf| { buf[0..COOKIE.len()].copy_from_slice(COOKIE); Ok(COOKIE.len()) }); + server_ctx.set_stateless_cookie_verify_cb(|_tls, buf| buf == COOKIE); + let mut server_stream = ssl::SslStreamBuilder::new(Ssl::new(&server_ctx.build()).unwrap(), MemoryStream::new()); + + // + // Handshake + // + + // Initial ClientHello + let mut client_stream = hs(client_stream.connect(MemoryStream::new())).unwrap_err(); + send(client_stream.get_mut(), server_stream.get_mut()); + // HelloRetryRequest + assert!(!server_stream.stateless().unwrap()); + send(server_stream.get_mut(), client_stream.get_mut()); + // Second ClientHello + let mut client_stream = hs(client_stream.handshake()).unwrap_err(); + send(client_stream.get_mut(), server_stream.get_mut()); + // ServerHello + assert!(server_stream.stateless().unwrap()); + let mut server_stream = hs(server_stream.accept()).unwrap_err(); + send(server_stream.get_mut(), client_stream.get_mut()); + // Finished + let mut client_stream = hs(client_stream.handshake()).unwrap(); + send(client_stream.get_mut(), server_stream.get_mut()); + hs(server_stream.handshake()).unwrap(); +} -- GitLab