Commit 1373a76c authored by Steven Fackler's avatar Steven Fackler
Browse files

Implement direct IO support

parent 9b235a7b
Loading
Loading
Loading
Loading
+169 −12
Original line number Diff line number Diff line
@@ -603,11 +603,6 @@ impl Ssl {
            return Err(SslError::get());
        }
        let ssl = Ssl { ssl: ssl };

        let rbio = try!(MemBio::new());
        let wbio = try!(MemBio::new());

        unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
        Ok(ssl)
    }

@@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> {
impl<S: Read+Write> IndirectStream<S> {
    fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());

        let rbio = try!(MemBio::new());
        let wbio = try!(MemBio::new());

        unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }

        Ok(IndirectStream {
            stream: stream,
            ssl: Arc::new(ssl),
@@ -852,27 +853,139 @@ impl<S: Read+Write> Write for IndirectStream<S> {
    }
}

#[derive(Clone)]
struct DirectStream<S> {
    stream: S,
    ssl: Arc<Ssl>,
}

impl DirectStream<net::TcpStream> {
    fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> {
        Ok(DirectStream {
            stream: try!(self.stream.try_clone()),
            ssl: self.ssl.clone(),
        })
    }
}

impl<S> DirectStream<S> {
    fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
        unsafe {
            let bio = ffi::BIO_new_socket(sock, 0);
            if bio == ptr::null_mut() {
                return Err(SslError::get());
            }
            ffi::SSL_set_bio(ssl.ssl, bio, bio);
        }

        Ok(DirectStream {
            stream: stream,
            ssl: Arc::new(ssl),
        })
    }

    fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
        let ssl = try!(DirectStream::new_base(ssl, stream, sock));
        let ret = ssl.ssl.connect();
        if ret > 0 {
            Ok(ssl)
        } else {
            Err(ssl.make_error(ret))
        }
    }

    fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
        let ssl = try!(DirectStream::new_base(ssl, stream, sock));
        let ret = ssl.ssl.accept();
        if ret > 0 {
            Ok(ssl)
        } else {
            Err(ssl.make_error(ret))
        }
    }

    fn make_error(&self, ret: c_int) -> SslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => SslError::get(),
            LibSslError::ErrorSyscall => {
                let err = SslError::get();
                let count = match err {
                    SslError::OpenSslErrors(ref v) => v.len(),
                    _ => unreachable!(),
                };
                if count == 0 {
                    if ret == 0 {
                        SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                             "unexpected EOF observed"))
                    } else {
                        SslError::StreamError(io::Error::last_os_error())
                    }
                } else {
                    err
                }
            }
            err => panic!("unexpected error {:?} with ret {}", err, ret),
        }
    }
}

impl<S> Read for DirectStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }
}

impl<S: Write> Write for DirectStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let ret = self.ssl.write(buf);
        if ret > 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        self.stream.flush()
    }
}

#[derive(Clone)]
enum StreamKind<S> {
    Indirect(IndirectStream<S>),
    Direct(DirectStream<S>),
}

impl<S> StreamKind<S> {
    fn stream(&self) -> &S {
        match *self {
            StreamKind::Indirect(ref s) => &s.stream
            StreamKind::Indirect(ref s) => &s.stream,
            StreamKind::Direct(ref s) => &s.stream,
        }
    }

    fn mut_stream(&mut self) -> &mut S {
        match *self {
            StreamKind::Indirect(ref mut s) => &mut s.stream
            StreamKind::Indirect(ref mut s) => &mut s.stream,
            StreamKind::Direct(ref mut s) => &mut s.stream,
        }
    }

    fn ssl(&self) -> &Ssl {
        match *self {
            StreamKind::Indirect(ref s) => &s.ssl
            StreamKind::Indirect(ref s) => &s.ssl,
            StreamKind::Direct(ref s) => &s.ssl,
        }
    }
}
@@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> {
    /// Create a new independently owned handle to the underlying socket.
    pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
        let kind = match self.kind {
            StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
            StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())),
            StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone()))
        };
        Ok(SslStream {
            kind: kind
@@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
    }
}

#[cfg(unix)]
impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> {
    pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let stream = try!(DirectStream::new_client(ssl, stream, fd));
        Ok(SslStream {
            kind: StreamKind::Direct(stream)
        })
    }

    pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let ssl = try!(ssl.into_ssl());
        let fd = stream.as_raw_fd() as c_int;
        let stream = try!(DirectStream::new_server(ssl, stream, fd));
        Ok(SslStream {
            kind: StreamKind::Direct(stream)
        })
    }
}

#[cfg(windows)]
impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> {
    pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let fd = stream.as_raw_socket() as c_int;
        let stream = try!(DirectStream::new_client(ssl, stream, fd));
        Ok(SslStream {
            kind: StreamKind::Direct(stream)
        })
    }

    pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let fd = stream.as_raw_socket() as c_int;
        let stream = try!(DirectStream::new_server(ssl, stream, fd));
        Ok(SslStream {
            kind: StreamKind::Direct(stream)
        })
    }
}

impl<S: Read+Write> SslStream<S> {
    pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
        let stream = try!(IndirectStream::new_client(ssl, stream));
@@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> {
impl<S: Read+Write> Read for SslStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.read(buf)
            StreamKind::Indirect(ref mut s) => s.read(buf),
            StreamKind::Direct(ref mut s) => s.read(buf),
        }
    }
}
@@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> {
impl<S: Read+Write> Write for SslStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.write(buf)
            StreamKind::Indirect(ref mut s) => s.write(buf),
            StreamKind::Direct(ref mut s) => s.write(buf),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self.kind {
            StreamKind::Indirect(ref mut s) => s.flush()
            StreamKind::Indirect(ref mut s) => s.flush(),
            StreamKind::Direct(ref mut s) => s.flush(),
        }
    }
}
+18 −1
Original line number Diff line number Diff line
@@ -317,8 +317,17 @@ fn test_write() {
    stream.flush().unwrap();
}

#[test]
fn test_write_direct() {
    let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
    let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
    stream.write_all("hello".as_bytes()).unwrap();
    stream.flush().unwrap();
    stream.write_all(" there".as_bytes()).unwrap();
    stream.flush().unwrap();
}

run_test!(get_peer_certificate, |method, stream| {
    //let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
    let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap();
    let cert = stream.get_peer_certificate().unwrap();
    let fingerprint = cert.fingerprint(SHA256).unwrap();
@@ -349,6 +358,14 @@ fn test_read() {
    io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
}

#[test]
fn test_read_direct() {
    let tcp = TcpStream::connect("127.0.0.1:15418").unwrap();
    let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap();
    stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap();
    stream.flush().unwrap();
    io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
}

#[test]
fn test_pending() {