Commit 1396143c authored by Steven Fackler's avatar Steven Fackler
Browse files

Add get_shutdown and set_shutdown

parent a062c62c
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -1319,6 +1319,9 @@ pub const SSL_SESS_CACHE_NO_INTERNAL_STORE: c_long = 0x200;
pub const SSL_SESS_CACHE_NO_INTERNAL: c_long =
    SSL_SESS_CACHE_NO_INTERNAL_LOOKUP | SSL_SESS_CACHE_NO_INTERNAL_STORE;

pub const SSL_SENT_SHUTDOWN: c_int = 1;
pub const SSL_RECEIVED_SHUTDOWN: c_int = 2;

pub const SSL3_VERSION: c_int = 0x300;
pub const TLS1_VERSION: c_int = 0x301;
pub const TLS1_1_VERSION: c_int = 0x302;
@@ -2476,6 +2479,8 @@ extern "C" {
    pub fn SSL_get0_param(ssl: *mut SSL) -> *mut X509_VERIFY_PARAM;
    pub fn SSL_get_verify_result(ssl: *const SSL) -> c_long;
    pub fn SSL_shutdown(ssl: *mut SSL) -> c_int;
    pub fn SSL_get_shutdown(ssl: *const SSL) -> c_int;
    pub fn SSL_set_shutdown(ss: *mut SSL, mode: c_int);
    pub fn SSL_get_certificate(ssl: *const SSL) -> *mut X509;
    #[cfg(not(ossl102))]
    pub fn SSL_get_privatekey(ssl: *mut SSL) -> *mut EVP_PKEY;
+53 −11
Original line number Diff line number Diff line
@@ -2734,8 +2734,7 @@ impl SslRef {
                self.as_ptr(),
                p as *mut c_uchar,
                response.len() as c_long,
            ) as c_int)
                .map(|_| ())
            ) as c_int).map(|_| ())
        }
    }

@@ -3016,6 +3015,30 @@ impl<S: Read + Write> SslStream<S> {
            n => Err(self.make_error(n)),
        }
    }

    /// Returns the session's shutdown state.
    ///
    /// This corresponds to [`SSL_get_shutdown`].
    ///
    /// [`SSL_get_shutdown`]: https://www.openssl.org/docs/man1.1.1/man3/SSL_set_shutdown.html
    pub fn get_shutdown(&mut self) -> ShutdownState {
        unsafe {
            let bits = ffi::SSL_get_shutdown(self.ssl.as_ptr());
            ShutdownState { bits }
        }
    }

    /// Sets the session's shutdown state.
    ///
    /// This can be used to tell OpenSSL that the session should be cached even if a full two-way
    /// shutdown was not completed.
    ///
    /// This corresponds to [`SSL_set_shutdown`].
    ///
    /// [`SSL_set_shutdown`]: https://www.openssl.org/docs/man1.1.1/man3/SSL_set_shutdown.html
    pub fn set_shutdown(&mut self, state: ShutdownState) {
        unsafe { ffi::SSL_set_shutdown(self.ssl.as_ptr(), state.bits()) }
    }
}

impl<S> SslStream<S> {
@@ -3188,9 +3211,12 @@ where
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
                        stream,
                        error,
                    }))
                }
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
@@ -3208,9 +3234,12 @@ where
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
                        stream,
                        error,
                    }))
                }
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
@@ -3234,9 +3263,12 @@ where
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
                    Err(HandshakeError::WouldBlock(MidHandshakeSslStream {
                        stream,
                        error,
                    }))
                }
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
@@ -3344,6 +3376,16 @@ pub enum ShutdownResult {
    Received,
}

bitflags! {
    /// The shutdown state of a session.
    pub struct ShutdownState: c_int {
        /// A close notify message has been sent to the peer.
        const SENT = ffi::SSL_SENT_SHUTDOWN;
        /// A close notify message has been received from the peer.
        const RECEIVED = ffi::SSL_RECEIVED_SHUTDOWN;
    }
}

cfg_if! {
    if #[cfg(any(ossl110, libressl273))] {
        use ffi::{SSL_CTX_up_ref, SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server};
+9 −2
Original line number Diff line number Diff line
@@ -22,8 +22,9 @@ use ssl;
#[cfg(any(ossl110, ossl111, libressl261))]
use ssl::SslVersion;
use ssl::{
    Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, Ssl, SslAcceptor, SslConnector,
    SslContext, SslFiletype, SslMethod, SslSessionCacheMode, SslStream, SslVerifyMode, StatusType,
    Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor,
    SslConnector, SslContext, SslFiletype, SslMethod, SslSessionCacheMode, SslStream,
    SslVerifyMode, StatusType,
};
#[cfg(any(ossl102, ossl110))]
use x509::verify::X509CheckFlags;
@@ -1013,8 +1014,14 @@ fn shutdown() {
    stream.read_exact(&mut buf).unwrap();
    assert_eq!(b"hello", &buf);

    assert_eq!(stream.get_shutdown(), ShutdownState::empty());
    assert_eq!(stream.shutdown().unwrap(), ShutdownResult::Sent);
    assert_eq!(stream.get_shutdown(), ShutdownState::SENT);
    assert_eq!(stream.shutdown().unwrap(), ShutdownResult::Received);
    assert_eq!(
        stream.get_shutdown(),
        ShutdownState::SENT | ShutdownState::RECEIVED
    );

    guard.join().unwrap();
}