Loading openssl-sys/src/lib.rs +4 −0 Original line number Diff line number Diff line Loading @@ -1564,6 +1564,10 @@ pub unsafe fn SSL_CTX_set_tlsext_servername_callback( SSL_CTX_callback_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_CB, cb) } pub unsafe fn SSL_CTX_set_tlsext_servername_arg(ctx: *mut SSL_CTX, arg: *mut c_void) -> c_long { SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg) } pub unsafe fn SSL_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long { SSL_ctrl( s, Loading openssl/src/ssl/callbacks.rs +4 −6 Original line number Diff line number Diff line Loading @@ -109,7 +109,8 @@ where let ssl = SslRef::from_ptr_mut(ssl); let callback_idx = SslContext::cached_ex_index::<F>(); let callback = ssl.ssl_context() let callback = ssl .ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let identity = if identity != ptr::null() { Loading Loading @@ -152,16 +153,13 @@ where } } pub extern "C" fn raw_sni<F>(ssl: *mut ffi::SSL, al: *mut c_int, _arg: *mut c_void) -> c_int pub extern "C" fn raw_sni<F>(ssl: *mut ffi::SSL, al: *mut c_int, arg: *mut c_void) -> c_int where F: Fn(&mut SslRef, &mut SslAlert) -> Result<(), SniError> + 'static + Sync + Send, { unsafe { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl .ssl_context() .ex_data(SslContext::cached_ex_index::<F>()) .expect("BUG: sni callback missing") as *const F; let callback = arg as *const F; let mut alert = SslAlert(*al); let r = (*callback)(ssl, &mut alert); Loading openssl/src/ssl/mod.rs +15 −7 Original line number Diff line number Diff line Loading @@ -658,7 +658,14 @@ impl SslContextBuilder { F: Fn(&mut SslRef, &mut SslAlert) -> Result<(), SniError> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); // The SNI callback is somewhat unique in that the callback associated with the original // context associated with an SSL can be used even if the SSL's context has been swapped // out. When that happens, we wouldn't be able to look up the callback's state in the // context's ex data. Instead, pass the pointer directly as the servername arg. It's // still stored in ex data to manage the lifetime. let arg = self.set_ex_data_inner(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_tlsext_servername_arg(self.as_ptr(), arg); let f: extern "C" fn(_, _, _) -> _ = raw_sni::<F>; let f: extern "C" fn() = mem::transmute(f); ffi::SSL_CTX_set_tlsext_servername_callback(self.as_ptr(), Some(f)); Loading Loading @@ -1514,13 +1521,14 @@ impl SslContextBuilder { /// /// [`SSL_CTX_set_ex_data`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_ex_data.html pub fn set_ex_data<T>(&mut self, index: Index<SslContext, T>, data: T) { self.set_ex_data_inner(index, data); } fn set_ex_data_inner<T>(&mut self, index: Index<SslContext, T>, data: T) -> *mut c_void { unsafe { let data = Box::new(data); ffi::SSL_CTX_set_ex_data( self.as_ptr(), index.as_raw(), Box::into_raw(data) as *mut c_void, ); let data = Box::into_raw(Box::new(data)) as *mut c_void; ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); data } } Loading openssl/src/ssl/test.rs +40 −1 Original line number Diff line number Diff line Loading @@ -29,7 +29,7 @@ use ssl::{ }; #[cfg(any(ossl102, ossl110))] use x509::verify::X509CheckFlags; use x509::{X509Name, X509StoreContext, X509VerifyResult, X509}; use x509::{X509, X509Name, X509StoreContext, X509VerifyResult}; use std::net::UdpSocket; Loading Loading @@ -1754,3 +1754,42 @@ fn psk_ciphers() { assert!(CLIENT_CALLED.load(Ordering::SeqCst) && SERVER_CALLED.load(Ordering::SeqCst)); } #[test] fn sni_callback_swapped_ctx() { static CALLED_BACK: AtomicBool = ATOMIC_BOOL_INIT; let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); let guard = thread::spawn(move || { let stream = listener.accept().unwrap().0; let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_servername_callback(|_, _| { CALLED_BACK.store(true, Ordering::SeqCst); Ok(()) }); let mut ssl = Ssl::new(&ctx.build()).unwrap(); let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_certificate_file(&Path::new("test/cert.pem"), SslFiletype::PEM) .unwrap(); ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM) .unwrap(); ssl.set_ssl_context(&ctx.build()).unwrap(); let mut stream = ssl.accept(stream).unwrap(); stream.write_all(&[0]).unwrap(); }); let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); let ssl = Ssl::new(&ctx.build()).unwrap(); let mut stream = ssl.connect(stream).unwrap(); stream.read_exact(&mut [0]).unwrap(); assert!(CALLED_BACK.load(Ordering::SeqCst)); guard.join().unwrap(); } Loading
openssl-sys/src/lib.rs +4 −0 Original line number Diff line number Diff line Loading @@ -1564,6 +1564,10 @@ pub unsafe fn SSL_CTX_set_tlsext_servername_callback( SSL_CTX_callback_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_CB, cb) } pub unsafe fn SSL_CTX_set_tlsext_servername_arg(ctx: *mut SSL_CTX, arg: *mut c_void) -> c_long { SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg) } pub unsafe fn SSL_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long { SSL_ctrl( s, Loading
openssl/src/ssl/callbacks.rs +4 −6 Original line number Diff line number Diff line Loading @@ -109,7 +109,8 @@ where let ssl = SslRef::from_ptr_mut(ssl); let callback_idx = SslContext::cached_ex_index::<F>(); let callback = ssl.ssl_context() let callback = ssl .ssl_context() .ex_data(callback_idx) .expect("BUG: psk callback missing") as *const F; let identity = if identity != ptr::null() { Loading Loading @@ -152,16 +153,13 @@ where } } pub extern "C" fn raw_sni<F>(ssl: *mut ffi::SSL, al: *mut c_int, _arg: *mut c_void) -> c_int pub extern "C" fn raw_sni<F>(ssl: *mut ffi::SSL, al: *mut c_int, arg: *mut c_void) -> c_int where F: Fn(&mut SslRef, &mut SslAlert) -> Result<(), SniError> + 'static + Sync + Send, { unsafe { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl .ssl_context() .ex_data(SslContext::cached_ex_index::<F>()) .expect("BUG: sni callback missing") as *const F; let callback = arg as *const F; let mut alert = SslAlert(*al); let r = (*callback)(ssl, &mut alert); Loading
openssl/src/ssl/mod.rs +15 −7 Original line number Diff line number Diff line Loading @@ -658,7 +658,14 @@ impl SslContextBuilder { F: Fn(&mut SslRef, &mut SslAlert) -> Result<(), SniError> + 'static + Sync + Send, { unsafe { self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); // The SNI callback is somewhat unique in that the callback associated with the original // context associated with an SSL can be used even if the SSL's context has been swapped // out. When that happens, we wouldn't be able to look up the callback's state in the // context's ex data. Instead, pass the pointer directly as the servername arg. It's // still stored in ex data to manage the lifetime. let arg = self.set_ex_data_inner(SslContext::cached_ex_index::<F>(), callback); ffi::SSL_CTX_set_tlsext_servername_arg(self.as_ptr(), arg); let f: extern "C" fn(_, _, _) -> _ = raw_sni::<F>; let f: extern "C" fn() = mem::transmute(f); ffi::SSL_CTX_set_tlsext_servername_callback(self.as_ptr(), Some(f)); Loading Loading @@ -1514,13 +1521,14 @@ impl SslContextBuilder { /// /// [`SSL_CTX_set_ex_data`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_ex_data.html pub fn set_ex_data<T>(&mut self, index: Index<SslContext, T>, data: T) { self.set_ex_data_inner(index, data); } fn set_ex_data_inner<T>(&mut self, index: Index<SslContext, T>, data: T) -> *mut c_void { unsafe { let data = Box::new(data); ffi::SSL_CTX_set_ex_data( self.as_ptr(), index.as_raw(), Box::into_raw(data) as *mut c_void, ); let data = Box::into_raw(Box::new(data)) as *mut c_void; ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); data } } Loading
openssl/src/ssl/test.rs +40 −1 Original line number Diff line number Diff line Loading @@ -29,7 +29,7 @@ use ssl::{ }; #[cfg(any(ossl102, ossl110))] use x509::verify::X509CheckFlags; use x509::{X509Name, X509StoreContext, X509VerifyResult, X509}; use x509::{X509, X509Name, X509StoreContext, X509VerifyResult}; use std::net::UdpSocket; Loading Loading @@ -1754,3 +1754,42 @@ fn psk_ciphers() { assert!(CLIENT_CALLED.load(Ordering::SeqCst) && SERVER_CALLED.load(Ordering::SeqCst)); } #[test] fn sni_callback_swapped_ctx() { static CALLED_BACK: AtomicBool = ATOMIC_BOOL_INIT; let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let port = listener.local_addr().unwrap().port(); let guard = thread::spawn(move || { let stream = listener.accept().unwrap().0; let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_servername_callback(|_, _| { CALLED_BACK.store(true, Ordering::SeqCst); Ok(()) }); let mut ssl = Ssl::new(&ctx.build()).unwrap(); let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); ctx.set_certificate_file(&Path::new("test/cert.pem"), SslFiletype::PEM) .unwrap(); ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM) .unwrap(); ssl.set_ssl_context(&ctx.build()).unwrap(); let mut stream = ssl.accept(stream).unwrap(); stream.write_all(&[0]).unwrap(); }); let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); let ssl = Ssl::new(&ctx.build()).unwrap(); let mut stream = ssl.connect(stream).unwrap(); stream.read_exact(&mut [0]).unwrap(); assert!(CALLED_BACK.load(Ordering::SeqCst)); guard.join().unwrap(); }