Unverified Commit 1092196c authored by Steven Fackler's avatar Steven Fackler
Browse files

Add uninit read methods and switch to _ex when possible

parent ea075d12
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -640,7 +640,13 @@ extern "C" {
    pub fn SSL_stateless(s: *mut SSL) -> c_int;
    pub fn SSL_connect(ssl: *mut SSL) -> c_int;
    pub fn SSL_read(ssl: *mut SSL, buf: *mut c_void, num: c_int) -> c_int;
    #[cfg(any(ossl111, libressl350))]
    pub fn SSL_read_ex(ssl: *mut SSL, buf: *mut c_void, num: usize, readbytes: *mut usize)
        -> c_int;
    pub fn SSL_peek(ssl: *mut SSL, buf: *mut c_void, num: c_int) -> c_int;
    #[cfg(any(ossl111, libressl350))]
    pub fn SSL_peek_ex(ssl: *mut SSL, buf: *mut c_void, num: usize, readbytes: *mut usize)
        -> c_int;
    #[cfg(any(ossl111, libressl340))]
    pub fn SSL_read_early_data(
        s: *mut SSL,
@@ -661,6 +667,13 @@ extern "C" {

extern "C" {
    pub fn SSL_write(ssl: *mut SSL, buf: *const c_void, num: c_int) -> c_int;
    #[cfg(any(ossl111, libressl350))]
    pub fn SSL_write_ex(
        ssl: *mut SSL,
        buf: *const c_void,
        num: size_t,
        written: *mut size_t,
    ) -> c_int;
    #[cfg(any(ossl111, libressl340))]
    pub fn SSL_write_early_data(
        s: *mut SSL,
+144 −64
Original line number Diff line number Diff line
@@ -90,14 +90,13 @@ use libc::{c_char, c_int, c_long, c_uchar, c_uint, c_void};
use once_cell::sync::{Lazy, OnceCell};
use openssl_macros::corresponds;
use std::any::TypeId;
use std::cmp;
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::marker::PhantomData;
use std::mem::{self, ManuallyDrop};
use std::mem::{self, ManuallyDrop, MaybeUninit};
use std::ops::{Deref, DerefMut};
use std::panic::resume_unwind;
use std::path::Path;
@@ -2367,21 +2366,6 @@ impl SslRef {
        unsafe { ffi::SSL_get_rbio(self.as_ptr()) }
    }

    fn read(&mut self, buf: &mut [u8]) -> c_int {
        let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
        unsafe { ffi::SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) }
    }

    fn peek(&mut self, buf: &mut [u8]) -> c_int {
        let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
        unsafe { ffi::SSL_peek(self.as_ptr(), buf.as_ptr() as *mut c_void, len) }
    }

    fn write(&mut self, buf: &[u8]) -> c_int {
        let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
        unsafe { ffi::SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) }
    }

    fn get_error(&self, ret: c_int) -> ErrorCode {
        unsafe { ErrorCode::from_raw(ffi::SSL_get_error(self.as_ptr(), ret)) }
    }
@@ -3750,63 +3734,167 @@ impl<S: Read + Write> SslStream<S> {
        }
    }

    /// Like `read`, but takes a possibly-uninitialized slice.
    ///
    /// # Safety
    ///
    /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`,
    /// then the first `n` bytes of `buf` are guaranteed to be initialized.
    #[corresponds(SSL_read_ex)]
    pub fn read_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
        loop {
            match self.ssl_read_uninit(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0),
                Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {
                    return Ok(0);
                }
                Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {}
                Err(e) => {
                    return Err(e
                        .into_io_error()
                        .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)));
                }
            }
        }
    }

    /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`.
    ///
    /// It is particularly useful with a non-blocking socket, where the error value will identify if
    /// OpenSSL is waiting on read or write readiness.
    #[corresponds(SSL_read)]
    #[corresponds(SSL_read_ex)]
    pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        // The interpretation of the return code here is a little odd with a
        // zero-length write. OpenSSL will likely correctly report back to us
        // that it read zero bytes, but zero is also the sentinel for "error".
        // To avoid that confusion short-circuit that logic and return quickly
        // if `buf` has a length of zero.
        // SAFETY: `ssl_read_uninit` does not de-initialize the buffer.
        unsafe {
            self.ssl_read_uninit(slice::from_raw_parts_mut(
                buf.as_mut_ptr().cast::<MaybeUninit<u8>>(),
                buf.len(),
            ))
        }
    }

    /// Like `read_ssl`, but takes a possibly-uninitialized slice.
    ///
    /// # Safety
    ///
    /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`,
    /// then the first `n` bytes of `buf` are guaranteed to be initialized.
    #[corresponds(SSL_read_ex)]
    pub fn ssl_read_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> Result<usize, Error> {
        cfg_if! {
            if #[cfg(any(ossl111, libressl350))] {
                let mut readbytes = 0;
                let ret = unsafe {
                    ffi::SSL_read_ex(
                        self.ssl().as_ptr(),
                        buf.as_mut_ptr().cast(),
                        buf.len(),
                        &mut readbytes,
                    )
                };

                if ret > 0 {
                    Ok(readbytes)
                } else {
                    Err(self.make_error(ret))
                }
            } else {
                if buf.is_empty() {
                    return Ok(0);
                }

        let ret = self.ssl.read(buf);
                let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int;
                let ret = unsafe {
                    ffi::SSL_read(self.ssl().as_ptr(), buf.as_mut_ptr().cast(), len)
                };
                if ret > 0 {
                    Ok(ret as usize)
                } else {
                    Err(self.make_error(ret))
                }
            }
        }
    }

    /// Like `write`, but returns an `ssl::Error` rather than an `io::Error`.
    ///
    /// It is particularly useful with a non-blocking socket, where the error value will identify if
    /// OpenSSL is waiting on read or write readiness.
    #[corresponds(SSL_write)]
    #[corresponds(SSL_write_ex)]
    pub fn ssl_write(&mut self, buf: &[u8]) -> Result<usize, Error> {
        // See above for why we short-circuit on zero-length buffers
        cfg_if! {
            if #[cfg(any(ossl111, libressl350))] {
                let mut written = 0;
                let ret = unsafe {
                    ffi::SSL_write_ex(
                        self.ssl().as_ptr(),
                        buf.as_ptr().cast(),
                        buf.len(),
                        &mut written,
                    )
                };

                if ret > 0 {
                    Ok(written)
                } else {
                    Err(self.make_error(ret))
                }
            } else {
                if buf.is_empty() {
                    return Ok(0);
                }

        let ret = self.ssl.write(buf);
                let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int;
                let ret = unsafe {
                    ffi::SSL_write(self.ssl().as_ptr(), buf.as_ptr().cast(), len)
                };
                if ret > 0 {
                    Ok(ret as usize)
                } else {
                    Err(self.make_error(ret))
                }
            }
        }
    }

    /// Reads data from the stream, without removing it from the queue.
    #[corresponds(SSL_peek)]
    #[corresponds(SSL_peek_ex)]
    pub fn ssl_peek(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        // See above for why we short-circuit on zero-length buffers
        cfg_if! {
            if #[cfg(any(ossl111, libressl350))] {
                let mut readbytes = 0;
                let ret = unsafe {
                    ffi::SSL_peek_ex(
                        self.ssl().as_ptr(),
                        buf.as_mut_ptr().cast(),
                        buf.len(),
                        &mut readbytes,
                    )
                };

                if ret > 0 {
                    Ok(readbytes)
                } else {
                    Err(self.make_error(ret))
                }
            } else {
                if buf.is_empty() {
                    return Ok(0);
                }

        let ret = self.ssl.peek(buf);
                let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int;
                let ret = unsafe {
                    ffi::SSL_peek(self.ssl().as_ptr(), buf.as_mut_ptr().cast(), len)
                };
                if ret > 0 {
                    Ok(ret as usize)
                } else {
                    Err(self.make_error(ret))
                }
            }
        }
    }

    /// Shuts down the session.
    ///
@@ -3910,20 +3998,12 @@ impl<S> SslStream<S> {

impl<S: Read + Write> Read for SslStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            match self.ssl_read(buf) {
                Ok(n) => return Ok(n),
                Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0),
                Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {
                    return Ok(0);
                }
                Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {}
                Err(e) => {
                    return Err(e
                        .into_io_error()
                        .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)));
                }
            }
        // SAFETY: `read_uninit` does not de-initialize the buffer
        unsafe {
            self.read_uninit(slice::from_raw_parts_mut(
                buf.as_mut_ptr().cast::<MaybeUninit<u8>>(),
                buf.len(),
            ))
        }
    }
}