Unverified Commit ad37e7e0 authored by Steven Fackler's avatar Steven Fackler Committed by GitHub
Browse files

Merge pull request #1160 from HyeonuPark/set-dtls-mtu

Add ssl::set_dtls_mtu_size(usize)
parents eda3ada7 4d5e5e17
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ pub const BIO_TYPE_NONE: c_int = 0;
pub const BIO_CTRL_EOF: c_int = 2;
pub const BIO_CTRL_INFO: c_int = 3;
pub const BIO_CTRL_FLUSH: c_int = 11;
pub const BIO_CTRL_DGRAM_QUERY_MTU: c_int = 40;
pub const BIO_C_SET_BUF_MEM_EOF_RETURN: c_int = 130;

extern "C" {
+14 −3
Original line number Diff line number Diff line
use ffi::{
    self, BIO_clear_retry_flags, BIO_new, BIO_set_retry_read, BIO_set_retry_write, BIO,
    BIO_CTRL_FLUSH,
    BIO_CTRL_FLUSH, BIO_CTRL_DGRAM_QUERY_MTU,
};
use libc::{c_char, c_int, c_long, c_void, strlen};
use std::any::Any;
@@ -18,6 +18,7 @@ pub struct StreamState<S> {
    pub stream: S,
    pub error: Option<io::Error>,
    pub panic: Option<Box<dyn Any + Send>>,
    pub dtls_mtu_size: c_long,
}

/// Safe wrapper for BIO_METHOD
@@ -39,6 +40,7 @@ pub fn new<S: Read + Write>(stream: S) -> Result<(*mut BIO, BioMethod), ErrorSta
        stream: stream,
        error: None,
        panic: None,
        dtls_mtu_size: 0,
    });

    unsafe {
@@ -69,6 +71,13 @@ pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S {
    &mut state(bio).stream
}

pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
    if mtu_size as u64 > c_long::max_value() as u64 {
        panic!("Given MTU size {} can't be represented in a positive `c_long` range")
    }
    state::<S>(bio).dtls_mtu_size = mtu_size as c_long;
}

unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> {
    &mut *(BIO_get_data(bio) as *mut _)
}
@@ -134,9 +143,9 @@ unsafe extern "C" fn ctrl<S: Write>(
    _num: c_long,
    _ptr: *mut c_void,
) -> c_long {
    if cmd == BIO_CTRL_FLUSH {
    let state = state::<S>(bio);

    if cmd == BIO_CTRL_FLUSH {
        match catch_unwind(AssertUnwindSafe(|| state.stream.flush())) {
            Ok(Ok(())) => 1,
            Ok(Err(err)) => {
@@ -148,6 +157,8 @@ unsafe extern "C" fn ctrl<S: Write>(
                0
            }
        }
    } else if cmd == BIO_CTRL_DGRAM_QUERY_MTU {
        state.dtls_mtu_size
    } else {
        0
    }
+14 −0
Original line number Diff line number Diff line
@@ -3804,6 +3804,20 @@ impl<S> SslStreamBuilder<S> {
    pub fn ssl(&self) -> &SslRef {
        &self.inner.ssl
    }

    /// Set the DTLS MTU size.
    ///
    /// It will be ignored if the value is smaller than the minimum packet size
    /// the DTLS protocol requires.
    ///
    /// # Panics
    /// This function panics if the given mtu size can't be represented in a positive `c_long` range
    pub fn set_dtls_mtu_size(&mut self, mtu_size: usize) {
        unsafe {
            let bio = self.inner.ssl.get_raw_rbio();
            bio::set_dtls_mtu_size::<S>(bio, mtu_size);
        }
    }
}

/// The result of a shutdown request.
+13 −5
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ use ssl::{ClientHelloResponse, ExtensionContext};
use ssl::{
    Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor,
    SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, SslMethod,
    SslOptions, SslSessionCacheMode, SslStream, SslVerifyMode, StatusType,
    SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, StatusType,
};
#[cfg(ossl102)]
use x509::store::X509StoreBuilder;
@@ -322,7 +322,9 @@ fn test_connect_with_srtp_ctx() {
        ctx.set_private_key_file(&Path::new("test/key.pem"), SslFiletype::PEM)
            .unwrap();
        let ssl = Ssl::new(&ctx.build()).unwrap();
        let mut stream = ssl.accept(stream).unwrap();
        let mut builder = SslStreamBuilder::new(ssl, stream);
        builder.set_dtls_mtu_size(1500);
        let mut stream = builder.accept().unwrap();

        let mut buf = [0; 60];
        stream
@@ -340,7 +342,9 @@ fn test_connect_with_srtp_ctx() {
    ctx.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32")
        .unwrap();
    let ssl = Ssl::new(&ctx.build()).unwrap();
    let mut stream = ssl.connect(stream).unwrap();
    let mut builder = SslStreamBuilder::new(ssl, stream);
    builder.set_dtls_mtu_size(1500);
    let mut stream = builder.connect().unwrap();

    let mut buf = [1; 60];
    {
@@ -390,7 +394,9 @@ fn test_connect_with_srtp_ssl() {
            "SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32",
            profilenames
        );
        let mut stream = ssl.accept(stream).unwrap();
        let mut builder = SslStreamBuilder::new(ssl, stream);
        builder.set_dtls_mtu_size(1500);
        let mut stream = builder.accept().unwrap();

        let mut buf = [0; 60];
        stream
@@ -408,7 +414,9 @@ fn test_connect_with_srtp_ssl() {
    let mut ssl = Ssl::new(&ctx.build()).unwrap();
    ssl.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32")
        .unwrap();
    let mut stream = ssl.connect(stream).unwrap();
    let mut builder = SslStreamBuilder::new(ssl, stream);
    builder.set_dtls_mtu_size(1500);
    let mut stream = builder.connect().unwrap();

    let mut buf = [1; 60];
    {