Loading openssl-sys/src/lib.rs +8 −0 Original line number Diff line number Diff line Loading @@ -2598,6 +2598,14 @@ extern "C" { -> c_uint, >, ); #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn SSL_CTX_set_psk_server_callback( ssl: *mut SSL_CTX, psk_server_cb: Option< extern "C" fn(*mut SSL, *const c_char, *mut c_uchar, c_uint) -> c_uint, >, ); pub fn SSL_select_next_proto( out: *mut *mut c_uchar, Loading openssl/src/ssl/callbacks.rs +43 −3 Original line number Diff line number Diff line Loading @@ -52,7 +52,7 @@ where } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub extern "C" fn raw_psk<F>( pub extern "C" fn raw_client_psk<F>( ssl: *mut ffi::SSL, hint: *const c_char, identity: *mut c_char, Loading @@ -74,7 +74,7 @@ where .ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let hint = if hint != ptr::null() { let hint = if !hint.is_null() { Some(CStr::from_ptr(hint).to_bytes()) } else { None Loading @@ -84,7 +84,47 @@ where let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize); match (*callback)(ssl, hint, identity_sl, psk_sl) { Ok(psk_len) => psk_len as u32, _ => 0, Err(e) => { e.put(); 0 } } } } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub extern "C" fn raw_server_psk<F>( ssl: *mut ffi::SSL, identity: *const c_char, psk: *mut c_uchar, max_psk_len: c_uint, ) -> c_uint where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { let ssl = SslRef::from_ptr_mut(ssl); let callback_idx = SslContext::cached_ex_index::<F>(); let callback = ssl.ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let identity = if identity != ptr::null() { Some(CStr::from_ptr(identity).to_bytes()) } else { None }; // Give the callback mutable slices into which it can write the psk. let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize); match (*callback)(ssl, identity, psk_sl) { Ok(psk_len) => psk_len as u32, Err(e) => { e.put(); 0 } } } } Loading openssl/src/ssl/mod.rs +36 −1 Original line number Diff line number Diff line Loading @@ -1252,16 +1252,51 @@ impl SslContextBuilder { /// /// [`SSL_CTX_set_psk_client_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_client_callback.html #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_client_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_client_psk::<F>)); } } #[deprecated(since = "0.10.10", note = "renamed to `set_psk_client_callback`")] #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { self.set_psk_client_callback(callback) } /// Sets the callback for providing an identity and pre-shared key for a TLS-PSK server. /// /// The callback will be called with the SSL context, an identity provided by the client, /// and, a mutable slice for the pre-shared key bytes. The callback returns the number of /// bytes in the pre-shared key. /// /// This corresponds to [`SSL_CTX_set_psk_server_callback`]. /// /// [`SSL_CTX_set_psk_server_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_server_callback.html #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_server_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_psk::<F>)); ffi::SSL_CTX_set_psk_server_callback(self.as_ptr(), Some(raw_server_psk::<F>)); } } Loading openssl/src/ssl/test.rs +49 −0 Original line number Diff line number Diff line Loading @@ -1568,3 +1568,52 @@ fn stateless() { send(client_stream.get_mut(), server_stream.get_mut()); hs(server_stream.handshake()).unwrap(); } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] #[test] fn psk_ciphers() { const CIPHER: &'static str = "PSK-AES128-CBC-SHA"; const PSK: &[u8] = b"thisisaverysecurekey"; const CLIENT_IDENT: &[u8] = b"thisisaclient"; static CLIENT_CALLED: AtomicBool = ATOMIC_BOOL_INIT; static SERVER_CALLED: AtomicBool = ATOMIC_BOOL_INIT; let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); thread::spawn(move || { let stream = listener.accept().unwrap().0; let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_cipher_list(CIPHER).unwrap(); ctx.set_psk_server_callback(move |_, identity, psk| { assert!(identity.unwrap_or(&[]) == CLIENT_IDENT); psk[..PSK.len()].copy_from_slice(&PSK); SERVER_CALLED.store(true, Ordering::SeqCst); Ok(PSK.len()) }); let ssl = Ssl::new(&ctx.build()).unwrap(); ssl.accept(stream).unwrap(); }); let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); // TLS 1.3 has no DH suites, and openssl isn't happy if the max version has no suites :( #[cfg(ossl111)] { ctx.set_options(super::SslOptions { bits: ::ffi::SSL_OP_NO_TLSv1_3, }); } ctx.set_cipher_list(CIPHER).unwrap(); ctx.set_psk_client_callback(move |_, _, identity, psk| { identity[..CLIENT_IDENT.len()].copy_from_slice(&CLIENT_IDENT); identity[CLIENT_IDENT.len()] = 0; psk[..PSK.len()].copy_from_slice(&PSK); CLIENT_CALLED.store(true, Ordering::SeqCst); Ok(PSK.len()) }); let ssl = Ssl::new(&ctx.build()).unwrap(); ssl.connect(stream).unwrap(); assert!(CLIENT_CALLED.load(Ordering::SeqCst) && SERVER_CALLED.load(Ordering::SeqCst)); } Loading
openssl-sys/src/lib.rs +8 −0 Original line number Diff line number Diff line Loading @@ -2598,6 +2598,14 @@ extern "C" { -> c_uint, >, ); #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn SSL_CTX_set_psk_server_callback( ssl: *mut SSL_CTX, psk_server_cb: Option< extern "C" fn(*mut SSL, *const c_char, *mut c_uchar, c_uint) -> c_uint, >, ); pub fn SSL_select_next_proto( out: *mut *mut c_uchar, Loading
openssl/src/ssl/callbacks.rs +43 −3 Original line number Diff line number Diff line Loading @@ -52,7 +52,7 @@ where } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub extern "C" fn raw_psk<F>( pub extern "C" fn raw_client_psk<F>( ssl: *mut ffi::SSL, hint: *const c_char, identity: *mut c_char, Loading @@ -74,7 +74,7 @@ where .ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let hint = if hint != ptr::null() { let hint = if !hint.is_null() { Some(CStr::from_ptr(hint).to_bytes()) } else { None Loading @@ -84,7 +84,47 @@ where let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize); match (*callback)(ssl, hint, identity_sl, psk_sl) { Ok(psk_len) => psk_len as u32, _ => 0, Err(e) => { e.put(); 0 } } } } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub extern "C" fn raw_server_psk<F>( ssl: *mut ffi::SSL, identity: *const c_char, psk: *mut c_uchar, max_psk_len: c_uint, ) -> c_uint where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { let ssl = SslRef::from_ptr_mut(ssl); let callback_idx = SslContext::cached_ex_index::<F>(); let callback = ssl.ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let identity = if identity != ptr::null() { Some(CStr::from_ptr(identity).to_bytes()) } else { None }; // Give the callback mutable slices into which it can write the psk. let psk_sl = slice::from_raw_parts_mut(psk as *mut u8, max_psk_len as usize); match (*callback)(ssl, identity, psk_sl) { Ok(psk_len) => psk_len as u32, Err(e) => { e.put(); 0 } } } } Loading
openssl/src/ssl/mod.rs +36 −1 Original line number Diff line number Diff line Loading @@ -1252,16 +1252,51 @@ impl SslContextBuilder { /// /// [`SSL_CTX_set_psk_client_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_client_callback.html #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_client_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_client_psk::<F>)); } } #[deprecated(since = "0.10.10", note = "renamed to `set_psk_client_callback`")] #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { self.set_psk_client_callback(callback) } /// Sets the callback for providing an identity and pre-shared key for a TLS-PSK server. /// /// The callback will be called with the SSL context, an identity provided by the client, /// and, a mutable slice for the pre-shared key bytes. The callback returns the number of /// bytes in the pre-shared key. /// /// This corresponds to [`SSL_CTX_set_psk_server_callback`]. /// /// [`SSL_CTX_set_psk_server_callback`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_psk_server_callback.html #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_server_callback<F>(&mut self, callback: F) where F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8]) -> Result<usize, ErrorStack> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_psk_client_callback(self.as_ptr(), Some(raw_psk::<F>)); ffi::SSL_CTX_set_psk_server_callback(self.as_ptr(), Some(raw_server_psk::<F>)); } } Loading
openssl/src/ssl/test.rs +49 −0 Original line number Diff line number Diff line Loading @@ -1568,3 +1568,52 @@ fn stateless() { send(client_stream.get_mut(), server_stream.get_mut()); hs(server_stream.handshake()).unwrap(); } #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] #[test] fn psk_ciphers() { const CIPHER: &'static str = "PSK-AES128-CBC-SHA"; const PSK: &[u8] = b"thisisaverysecurekey"; const CLIENT_IDENT: &[u8] = b"thisisaclient"; static CLIENT_CALLED: AtomicBool = ATOMIC_BOOL_INIT; static SERVER_CALLED: AtomicBool = ATOMIC_BOOL_INIT; let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); thread::spawn(move || { let stream = listener.accept().unwrap().0; let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_cipher_list(CIPHER).unwrap(); ctx.set_psk_server_callback(move |_, identity, psk| { assert!(identity.unwrap_or(&[]) == CLIENT_IDENT); psk[..PSK.len()].copy_from_slice(&PSK); SERVER_CALLED.store(true, Ordering::SeqCst); Ok(PSK.len()) }); let ssl = Ssl::new(&ctx.build()).unwrap(); ssl.accept(stream).unwrap(); }); let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); // TLS 1.3 has no DH suites, and openssl isn't happy if the max version has no suites :( #[cfg(ossl111)] { ctx.set_options(super::SslOptions { bits: ::ffi::SSL_OP_NO_TLSv1_3, }); } ctx.set_cipher_list(CIPHER).unwrap(); ctx.set_psk_client_callback(move |_, _, identity, psk| { identity[..CLIENT_IDENT.len()].copy_from_slice(&CLIENT_IDENT); identity[CLIENT_IDENT.len()] = 0; psk[..PSK.len()].copy_from_slice(&PSK); CLIENT_CALLED.store(true, Ordering::SeqCst); Ok(PSK.len()) }); let ssl = Ssl::new(&ctx.build()).unwrap(); ssl.connect(stream).unwrap(); assert!(CLIENT_CALLED.load(Ordering::SeqCst) && SERVER_CALLED.load(Ordering::SeqCst)); }