Commit 99fdb2bd authored by Benjamin Saunders's avatar Benjamin Saunders
Browse files

Introduce SslStreamBuilder

parent bbb1cb61
Loading
Loading
Loading
Loading
+114 −34
Original line number Diff line number Diff line
@@ -1432,7 +1432,8 @@ impl SslContextBuilder {
        }
    }

    /// Sets the callback for generating an application cookie for stateless handshakes.
    /// Sets the callback for generating an application cookie for TLS1.3
    /// stateless handshakes.
    ///
    /// The callback will be called with the SSL context and a slice into which the cookie
    /// should be written. The callback should return the number of bytes written.
@@ -1454,7 +1455,8 @@ impl SslContextBuilder {
        }
    }

    /// Sets the callback for verifying an application cookie for stateless handshakes.
    /// Sets the callback for verifying an application cookie for TLS1.3
    /// stateless handshakes.
    ///
    /// The callback will be called with the SSL context and the cookie supplied by the
    /// client. It should return true if and only if the cookie is valid.
@@ -2632,22 +2634,7 @@ impl Ssl {
    where
        S: Read + Write,
    {
        let mut stream = SslStream::new_base(self, stream);
        let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) };
        if ret > 0 {
            Ok(stream)
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
        SslStreamBuilder::new(self, stream).connect()
    }

    /// Initiates a server-side TLS handshake.
@@ -2664,22 +2651,7 @@ impl Ssl {
    where
        S: Read + Write,
    {
        let mut stream = SslStream::new_base(self, stream);
        let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) };
        if ret > 0 {
            Ok(stream)
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
        SslStreamBuilder::new(self, stream).accept()
    }
}

@@ -2951,6 +2923,114 @@ impl<S: Read + Write> Write for SslStream<S> {
    }
}

/// A partially constructed `SslStream`, useful for unusual handshakes.
pub struct SslStreamBuilder<S> {
    inner: SslStream<S>
}

impl<S> SslStreamBuilder<S>
    where S: Read + Write
{
    /// Begin creating an `SslStream` atop `stream`
    pub fn new(ssl: Ssl, stream: S) -> Self {
        Self {
            inner: SslStream::new_base(ssl, stream),
        }
    }

    /// Perform a stateless server-side handshake
    ///
    /// Requires that cookie generation and verification callbacks were
    /// set on the SSL context.
    ///
    /// Returns `Ok(true)` if a complete ClientHello containing a valid cookie
    /// was read, in which case the handshake should be continued via
    /// `accept`. If a HelloRetryRequest containing a fresh cookie was
    /// transmitted, `Ok(false)` is returned instead. If the handshake cannot
    /// proceed at all, `Err` is returned.
    ///
    /// This corresponds to [`SSL_stateless`]
    ///
    /// [`SSL_stateless`]: https://www.openssl.org/docs/manmaster/man3/SSL_stateless.html
    #[cfg(ossl111)]
    pub fn stateless(&mut self) -> Result<bool, ErrorStack> {
        match unsafe { ffi::SSL_stateless(self.inner.ssl.as_ptr()) } {
            1 => Ok(true),
            0 => Ok(false),
            -1 => Err(ErrorStack::get()),
            _ => unreachable!(),
        }
    }

    /// See `Ssl::connect`
    pub fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
        let mut stream = self.inner;
        let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) };
        if ret > 0 {
            Ok(stream)
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
    }

    /// See `Ssl::accept`
    pub fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
        let mut stream = self.inner;
        let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) };
        if ret > 0 {
            Ok(stream)
        } else {
            let error = stream.make_error(ret);
            match error.code() {
                ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock(
                    MidHandshakeSslStream { stream, error },
                )),
                _ => Err(HandshakeError::Failure(MidHandshakeSslStream {
                    stream,
                    error,
                })),
            }
        }
    }

    // Future work: early IO methods
}

impl<S> SslStreamBuilder<S> {
    /// Returns a shared reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        unsafe {
            let bio = self.inner.ssl.get_raw_rbio();
            bio::get_ref(bio)
        }
    }

    /// Returns a mutable reference to the underlying stream.
    ///
    /// # Warning
    ///
    /// It is inadvisable to read from or write to the underlying stream as it
    /// will most likely corrupt the SSL session.
    pub fn get_mut(&mut self) -> &mut S {
        unsafe {
            let bio = self.inner.ssl.get_raw_rbio();
            bio::get_mut(bio)
        }
    }

    /// Returns a shared reference to the `Ssl` object associated with this builder.
    pub fn ssl(&self) -> &SslRef { &self.inner.ssl }
}

/// The result of a shutdown request.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ShutdownResult {