Unverified Commit 78abc9b6 authored by Steven Fackler's avatar Steven Fackler Committed by GitHub
Browse files

Merge pull request #916 from sfackler/ssl-callback-cleanup

Clean up SSL callbacks
parents af5a7178 c25b6f3e
Loading
Loading
Loading
Loading
+36 −29
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ use std::ptr;
use std::slice;
#[cfg(ossl111)]
use std::str;
use std::sync::Arc;

use dh::Dh;
#[cfg(any(ossl101, ossl102))]
@@ -20,10 +21,7 @@ use pkey::Params;
use ssl::AlpnError;
#[cfg(ossl111)]
use ssl::ExtensionContext;
use ssl::{
    get_ssl_callback_idx, SniError, SslAlert, SslContext, SslContextRef, SslRef, SslSession,
    SslSessionRef,
};
use ssl::{SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef};
#[cfg(ossl111)]
use x509::X509Ref;
use x509::{X509StoreContext, X509StoreContextRef};
@@ -94,14 +92,17 @@ where
    F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send,
{
    unsafe {
        let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx();
        let ssl = ffi::X509_STORE_CTX_get_ex_data(x509_ctx, idx);
        let verify = ffi::SSL_get_ex_data(ssl as *const _, get_ssl_callback_idx::<F>());
        let verify: &F = &*(verify as *mut F);

        let ctx = X509StoreContextRef::from_ptr_mut(x509_ctx);
        let ssl_idx = X509StoreContext::ssl_idx().expect("BUG: store context ssl index missing");
        let callback_idx = Ssl::cached_ex_index::<Arc<F>>();

        let callback = ctx.ex_data(ssl_idx)
            .expect("BUG: store context missing ssl")
            .ex_data(callback_idx)
            .expect("BUG: ssl verify callback missing")
            .clone();

        verify(preverify_ok != 0, ctx) as c_int
        callback(preverify_ok != 0, ctx) as c_int
    }
}

@@ -216,10 +217,11 @@ pub unsafe extern "C" fn raw_tmp_dh_ssl<F>(
where
    F: Fn(&mut SslRef, bool, u32) -> Result<Dh<Params>, ErrorStack> + 'static + Sync + Send,
{
    let callback = ffi::SSL_get_ex_data(ssl, get_ssl_callback_idx::<F>());
    let callback = &*(callback as *mut F);

    let ssl = SslRef::from_ptr_mut(ssl);
    let callback = ssl.ex_data(Ssl::cached_ex_index::<Arc<F>>())
        .expect("BUG: ssl tmp dh callback missing")
        .clone();

    match callback(ssl, is_export != 0, keylength as u32) {
        Ok(dh) => {
            let ptr = dh.as_ptr();
@@ -242,10 +244,11 @@ pub unsafe extern "C" fn raw_tmp_ecdh_ssl<F>(
where
    F: Fn(&mut SslRef, bool, u32) -> Result<EcKey<Params>, ErrorStack> + 'static + Sync + Send,
{
    let callback = ffi::SSL_get_ex_data(ssl, get_ssl_callback_idx::<F>());
    let callback = &*(callback as *mut F);

    let ssl = SslRef::from_ptr_mut(ssl);
    let callback = ssl.ex_data(Ssl::cached_ex_index::<Arc<F>>())
        .expect("BUG: ssl tmp ecdh callback missing")
        .clone();

    match callback(ssl, is_export != 0, keylength as u32) {
        Ok(ec_key) => {
            let ptr = ec_key.as_ptr();
@@ -503,18 +506,20 @@ where
        match (*callback)(ssl, ectx, cert) {
            Ok(None) => 0,
            Ok(Some(buf)) => {
                *outlen = buf.as_ref().len() as size_t;
                *outlen = buf.as_ref().len();
                *out = buf.as_ref().as_ptr();

                let idx = get_ssl_callback_idx::<CustomExtAddState<T>>();
                let ptr = ffi::SSL_get_ex_data(ssl.as_ptr(), idx);
                if ptr.is_null() {
                    let x = Box::into_raw(Box::<CustomExtAddState<T>>::new(CustomExtAddState(
                        Some(buf),
                    ))) as *mut c_void;
                    ffi::SSL_set_ex_data(ssl.as_ptr(), idx, x);
                } else {
                    *(ptr as *mut _) = CustomExtAddState(Some(buf))
                let idx = Ssl::cached_ex_index::<CustomExtAddState<T>>();
                let mut buf = Some(buf);
                let new = match ssl.ex_data_mut(idx) {
                    Some(state) => {
                        state.0 = buf.take();
                        false
                    }
                    None => true,
                };
                if new {
                    ssl.set_ex_data(idx, CustomExtAddState(buf));
                }
                1
            }
@@ -537,9 +542,11 @@ pub extern "C" fn raw_custom_ext_free<T>(
    T: 'static + Sync + Send,
{
    unsafe {
        let state = ffi::SSL_get_ex_data(ssl, get_ssl_callback_idx::<CustomExtAddState<T>>());
        let state = &mut (*(state as *mut CustomExtAddState<T>)).0;
        state.take();
        let ssl = SslRef::from_ptr_mut(ssl);
        let idx = Ssl::cached_ex_index::<CustomExtAddState<T>>();
        if let Some(state) = ssl.ex_data_mut(idx) {
            state.0 = None;
        }
    }
}

+40 −39
Original line number Diff line number Diff line
@@ -75,7 +75,7 @@ use std::path::Path;
use std::ptr;
use std::slice;
use std::str;
use std::sync::Mutex;
use std::sync::{Arc, Mutex};

use dh::{Dh, DhRef};
#[cfg(any(ossl101, ossl102))]
@@ -466,14 +466,6 @@ lazy_static! {
    static ref SSL_INDEXES: Mutex<HashMap<TypeId, c_int>> = Mutex::new(HashMap::new());
}

fn get_ssl_callback_idx<T: 'static>() -> c_int {
    *SSL_INDEXES
        .lock()
        .unwrap()
        .entry(TypeId::of::<T>())
        .or_insert_with(|| get_new_ssl_idx::<T>())
}

unsafe extern "C" fn free_data_box<T>(
    _parent: *mut c_void,
    ptr: *mut c_void,
@@ -487,14 +479,6 @@ unsafe extern "C" fn free_data_box<T>(
    }
}

fn get_new_ssl_idx<T>() -> c_int {
    unsafe {
        let idx = compat::get_new_ssl_idx(free_data_box::<T>);
        assert!(idx >= 0);
        idx
    }
}

/// An error returned from the SNI callback.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SniError(c_int);
@@ -1937,6 +1921,21 @@ impl Ssl {
            Ok(Index::from_raw(idx))
        }
    }

    // FIXME should return a result?
    fn cached_ex_index<T>() -> Index<Ssl, T>
    where
        T: 'static + Sync + Send,
    {
        unsafe {
            let idx = *SSL_INDEXES
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .entry(TypeId::of::<T>())
                .or_insert_with(|| Ssl::new_ex_index::<T>().unwrap().as_raw());
            Index::from_raw(idx)
        }
    }
}

impl fmt::Debug for SslRef {
@@ -1988,12 +1987,8 @@ impl SslRef {
        F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send,
    {
        unsafe {
            let verify = Box::new(verify);
            ffi::SSL_set_ex_data(
                self.as_ptr(),
                get_ssl_callback_idx::<F>(),
                mem::transmute(verify),
            );
            // this needs to be in an Arc since the callback can register a new callback!
            self.set_ex_data(Ssl::cached_ex_index(), Arc::new(verify));
            ffi::SSL_set_verify(self.as_ptr(), mode.bits as c_int, Some(ssl_raw_verify::<F>));
        }
    }
@@ -2019,14 +2014,9 @@ impl SslRef {
        F: Fn(&mut SslRef, bool, u32) -> Result<Dh<Params>, ErrorStack> + 'static + Sync + Send,
    {
        unsafe {
            let callback = Box::new(callback);
            ffi::SSL_set_ex_data(
                self.as_ptr(),
                get_ssl_callback_idx::<F>(),
                Box::into_raw(callback) as *mut c_void,
            );
            let f: unsafe extern "C" fn(_, _, _) -> _ = raw_tmp_dh_ssl::<F>;
            ffi::SSL_set_tmp_dh_callback(self.as_ptr(), f);
            // this needs to be in an Arc since the callback can register a new callback!
            self.set_ex_data(Ssl::cached_ex_index(), Arc::new(callback));
            ffi::SSL_set_tmp_dh_callback(self.as_ptr(), raw_tmp_dh_ssl::<F>);
        }
    }

@@ -2052,14 +2042,9 @@ impl SslRef {
        F: Fn(&mut SslRef, bool, u32) -> Result<EcKey<Params>, ErrorStack> + 'static + Sync + Send,
    {
        unsafe {
            let callback = Box::new(callback);
            ffi::SSL_set_ex_data(
                self.as_ptr(),
                get_ssl_callback_idx::<F>(),
                Box::into_raw(callback) as *mut c_void,
            );
            let f: unsafe extern "C" fn(_, _, _) -> _ = raw_tmp_ecdh_ssl::<F>;
            ffi::SSL_set_tmp_ecdh_callback(self.as_ptr(), f);
            // this needs to be in an Arc since the callback can register a new callback!
            self.set_ex_data(Ssl::cached_ex_index(), Arc::new(callback));
            ffi::SSL_set_tmp_ecdh_callback(self.as_ptr(), raw_tmp_ecdh_ssl::<F>);
        }
    }

@@ -2538,6 +2523,22 @@ impl SslRef {
            }
        }
    }

    /// Returns a mutable reference to the extra data at the specified index.
    ///
    /// This corresponds to [`SSL_get_ex_data`].
    ///
    /// [`SSL_get_ex_data`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_ex_data.html
    pub fn ex_data_mut<T>(&mut self, index: Index<Ssl, T>) -> Option<&mut T> {
        unsafe {
            let data = ffi::SSL_get_ex_data(self.as_ptr(), index.as_raw());
            if data.is_null() {
                None
            } else {
                Some(&mut *(data as *mut T))
            }
        }
    }
}

unsafe impl Sync for Ssl {}