Unverified Commit 6440ee04 authored by Steven Fackler's avatar Steven Fackler Committed by GitHub
Browse files

Merge pull request #943 from lolzballs/master

 Add wrapper for SSL_CTX_set_psk_server_callback 
parents 0f815628 bf86580b
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -2598,6 +2598,14 @@ extern "C" {
                -> c_uint,
        >,
    );
    #[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
    pub fn SSL_CTX_set_psk_server_callback(
        ssl: *mut SSL_CTX,
        psk_server_cb: Option<
            extern "C" fn(*mut SSL, *const c_char, *mut c_uchar, c_uint)
                -> c_uint,
        >,
    );

    pub fn SSL_select_next_proto(
        out: *mut *mut c_uchar,
+43 −3
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ where
}

#[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
pub extern "C" fn raw_psk<F>(
pub extern "C" fn raw_client_psk<F>(
    ssl: *mut ffi::SSL,
    hint: *const c_char,
    identity: *mut c_char,
@@ -74,7 +74,7 @@ where
            .ssl_context()
            .ex_data(callback_idx)
            .expect("BUG: psk callback missing") as *const F;
        let hint = if hint != ptr::null() {
        let hint = if !hint.is_null() {
            Some(CStr::from_ptr(hint).to_bytes())
        } else {
            None
@@ -84,7 +84,47 @@ where
        let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize);
        match (*callback)(ssl, hint, identity_sl, psk_sl) {
            Ok(psk_len) => psk_len as u32,
            _ => 0,
            Err(e) => {
                e.put();
                0
            }
        }
    }
}

#[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
pub extern "C" fn raw_server_psk<F>(
    ssl: *mut ffi::SSL,
    identity: *const c_char,
    psk: *mut c_uchar,
    max_psk_len: c_uint,
) -> c_uint
where
    F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack>
        + 'static
        + Sync
        + Send,
{
    unsafe {
        let ssl = SslRef::from_ptr_mut(ssl);
        let callback_idx = SslContext::cached_ex_index::<F>();

        let callback = ssl.ssl_context()
            .ex_data(callback_idx)
            .expect("BUG: psk callback missing") as *const F;
        let identity = if identity != ptr::null() {
            Some(CStr::from_ptr(identity).to_bytes())
        } else {
            None
        };
        // Give the callback mutable slices into which it can write the psk.
        let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize);
        match (*callback)(ssl, identity, psk_sl) {
            Ok(psk_len) => psk_len as u32,
            Err(e) => {
                e.put();
                0
            }
        }
    }
}
+36 −1
Original line number Diff line number Diff line
@@ -1252,16 +1252,51 @@ impl SslContextBuilder {
    ///
    /// [`SSL_CTX_set_psk_client_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_client_callback.html
    #[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
    pub fn set_psk_client_callback<F>(&mut self, callback: F)
    where
        F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack>
            + 'static
            + Sync
            + Send,
    {
        unsafe {
            self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
            ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_client_psk::<F>));
        }
    }

    #[deprecated(since = "0.10.10", note = "renamed to `set_psk_client_callback`")]
    #[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
    pub fn set_psk_callback<F>(&mut self, callback: F)
    where
        F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack>
            + 'static
            + Sync
            + Send,
    {
        self.set_psk_client_callback(callback)
    }

    /// Sets the callback for providing an identity and pre-shared key for a TLS-PSK server.
    ///
    /// The callback will be called with the SSL context, an identity provided by the client,
    /// and, a mutable slice for the pre-shared key bytes. The callback returns the number of
    /// bytes in the pre-shared key.
    ///
    /// This corresponds to [`SSL_CTX_set_psk_server_callback`].
    ///
    /// [`SSL_CTX_set_psk_server_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_server_callback.html
    #[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
    pub fn set_psk_server_callback<F>(&mut self, callback: F)
    where
        F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack>
            + 'static
            + Sync
            + Send,
    {
        unsafe {
            self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
            ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_psk::<F>));
            ffi::SSL_CTX_set_psk_server_callback(self.as_ptr(), Some(raw_server_psk::<F>));
        }
    }

+49 −0
Original line number Diff line number Diff line
@@ -1568,3 +1568,52 @@ fn stateless() {
    send(client_stream.get_mut(), server_stream.get_mut());
    hs(server_stream.handshake()).unwrap();
}

#[cfg(not(osslconf = "OPENSSL_NO_PSK"))]
#[test]
fn psk_ciphers() {
    const CIPHER: &'static str = "PSK-AES128-CBC-SHA";
    const PSK: &[u8] = b"thisisaverysecurekey";
    const CLIENT_IDENT: &[u8] = b"thisisaclient";
    static CLIENT_CALLED: AtomicBool = ATOMIC_BOOL_INIT;
    static SERVER_CALLED: AtomicBool = ATOMIC_BOOL_INIT;

    let listener = TcpListener::bind("127.0.0.1:0").unwrap();
    let port = listener.local_addr().unwrap().port();

    thread::spawn(move || {
        let stream = listener.accept().unwrap().0;
        let mut ctx = SslContext::builder(SslMethod::tls()).unwrap();
        ctx.set_cipher_list(CIPHER).unwrap();
        ctx.set_psk_server_callback(move |_, identity, psk| {
            assert!(identity.unwrap_or(&[]) == CLIENT_IDENT);
            psk[..PSK.len()].copy_from_slice(&PSK);
            SERVER_CALLED.store(true, Ordering::SeqCst);
            Ok(PSK.len())
        });
        let ssl = Ssl::new(&ctx.build()).unwrap();
        ssl.accept(stream).unwrap();
    });

    let stream = TcpStream::connect(("127.0.0.1", port)).unwrap();
    let mut ctx = SslContext::builder(SslMethod::tls()).unwrap();
    // TLS 1.3 has no DH suites, and openssl isn't happy if the max version has no suites :(
    #[cfg(ossl111)]
    {
        ctx.set_options(super::SslOptions {
            bits: ::ffi::SSL_OP_NO_TLSv1_3,
        });
    }
    ctx.set_cipher_list(CIPHER).unwrap();
    ctx.set_psk_client_callback(move |_, _, identity, psk| {
        identity[..CLIENT_IDENT.len()].copy_from_slice(&CLIENT_IDENT);
        identity[CLIENT_IDENT.len()] = 0;
        psk[..PSK.len()].copy_from_slice(&PSK);
        CLIENT_CALLED.store(true, Ordering::SeqCst);
        Ok(PSK.len())
    });
    let ssl = Ssl::new(&ctx.build()).unwrap();
    ssl.connect(stream).unwrap();

    assert!(CLIENT_CALLED.load(Ordering::SeqCst) && SERVER_CALLED.load(Ordering::SeqCst));
}