Commit 16183f41 authored by Ted Mielczarek's avatar Ted Mielczarek
Browse files

Expose PSK via a SslContextBuilder::set_psk_callback method

parent 52e82faa
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -2252,6 +2252,14 @@ extern "C" {
                                                              arg: *mut c_void)
                                                              -> c_int,
                                            arg: *mut c_void);
    pub fn SSL_CTX_set_psk_client_callback(ssl: *mut SSL_CTX,
                                           psk_client_cb: extern "C" fn(*mut SSL,
                                                                        *const c_char,
                                                                        *mut c_char,
                                                                        c_uint,
                                                                        *mut c_uchar,
                                                                        c_uint)
                                                                        -> c_uint);
    pub fn SSL_select_next_proto(out: *mut *mut c_uchar,
                                 outlen: *mut c_uchar,
                                 inbuf: *const c_uchar,
+47 −2
Original line number Diff line number Diff line
@@ -73,10 +73,10 @@
use ffi;
use foreign_types::{ForeignType, ForeignTypeRef};
use libc::{c_int, c_void, c_long, c_ulong};
use libc::{c_uchar, c_uint};
use libc::{c_char, c_uchar, c_uint};
use std::any::Any;
use std::any::TypeId;
use std::borrow::Borrow;
use std::borrow::{Borrow, Cow};
use std::cmp;
use std::collections::HashMap;
use std::ffi::{CStr, CString};
@@ -307,6 +307,39 @@ extern "C" fn raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_
    }
}

extern "C" fn raw_psk<F>(ssl: *mut ffi::SSL,
                         hint: *const c_char,
                         identity: *mut c_char,
                         _max_identity_len: c_uint,
                         psk: *mut c_uchar,
                         _max_psk_len: c_uint) -> c_uint
    where F: Fn(&mut SslRef, &str) -> Result<(String, Vec<u8>), ErrorStack> + Any + 'static + Sync + Send
{
    unsafe {
        let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl as *const _);
        let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::<F>());
        let ssl = SslRef::from_ptr_mut(ssl);
        let callback = &*(callback as *mut F);
        let cstr_hint = if hint != ptr::null() { Cow::Borrowed(CStr::from_ptr(hint)) } else { Cow::Owned(CString::new("").unwrap()) };
        if let Ok(s) = cstr_hint.to_str() {
            match callback(ssl, s) {
                Ok((identity_out, psk_out)) => {
                    if let Ok(id) = CString::new(identity_out) {
                        //TODO: validate max_identity_len, max_psk_len
                        let id = id.into_bytes_with_nul();
                        ptr::copy(id.as_ptr() as *mut i8, identity, id.len());
                        ptr::copy(psk_out.as_ptr(), psk, psk_out.len());
                        psk_out.len() as u32
                    } else { 0 }
                }
                Err(_) => 0,
            }
        } else {
            0
        }
    }
}

extern "C" fn ssl_raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int
    where F: Fn(bool, &X509StoreContextRef) -> bool + Any + 'static + Sync + Send
{
@@ -977,6 +1010,18 @@ impl SslContextBuilder {
        }
    }

    pub fn set_psk_callback<F>(&mut self, callback: F)
        where F: Fn(&mut SslRef, &str) -> Result<(String, Vec<u8>), ErrorStack> + Any + 'static + Sync + Send
    {
        unsafe {
            let callback = Box::new(callback);
            ffi::SSL_CTX_set_ex_data(self.as_ptr(),
                                     get_callback_idx::<F>(),
                                     mem::transmute(callback));
            ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), raw_psk::<F>)
        }
    }

    pub fn build(self) -> SslContext {
        let ctx = SslContext(self.0);
        mem::forget(self);