Commit 8f568970 authored by Steven Fackler's avatar Steven Fackler
Browse files

Implement read and write

parent 9ee6f1c5
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -55,6 +55,8 @@ extern {
    pub fn BIO_set_retry_read(b: *mut BIO);
    #[link_name = "BIO_set_retry_write_shim"]
    pub fn BIO_set_retry_write(b: *mut BIO);
    #[link_name = "BIO_flush"]
    pub fn BIO_flush(b: *mut BIO) -> c_long;
    pub fn SSL_CTX_set_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long;
    pub fn SSL_CTX_get_options_shim(ctx: *mut SSL_CTX) -> c_long;
    pub fn SSL_CTX_clear_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long;
+4 −0
Original line number Diff line number Diff line
@@ -103,6 +103,10 @@ void BIO_set_retry_write_shim(BIO *b) {
    BIO_set_retry_write(b);
}

long BIO_flush_shim(BIO *b) {
    return BIO_flush(b);
}

long SSL_CTX_set_options_shim(SSL_CTX *ctx, long options) {
    return SSL_CTX_set_options(ctx, options);
}
+130 −96
Original line number Diff line number Diff line
@@ -1350,102 +1350,6 @@ impl<S: Read+Write> Write for SslStream<S> {
    }
}

pub struct SslStreamNg<S> {
    ssl: Ssl,
    _method: Box<ffi::BIO_METHOD>, // :(
    _p: PhantomData<S>,
}

impl<S> Drop for SslStreamNg<S> {
    fn drop(&mut self) {
        unsafe {
            let _ = bio::take_stream::<S>(self.ssl.get_raw_rbio());
        }
    }
}

impl<S: Read + Write> SslStreamNg<S> {
    fn new_base(ssl: Ssl, stream: S) -> Result<Self, SslError> {
        unsafe {
            let (bio, method) = try!(bio::new(stream));
            ffi::SSL_set_bio(ssl.ssl, bio, bio);

            Ok(SslStreamNg {
                ssl: ssl,
                _method: method,
                _p: PhantomData,
            })
        }
    }

    /// Creates an SSL/TLS client operating over the provided stream.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> {
        let ssl = try!(ssl.into_ssl());
        let mut stream = try!(Self::new_base(ssl, stream));
        let ret = stream.ssl.connect();
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
        }
    }

    /// Creates an SSL/TLS server operating over the provided stream.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> {
        let ssl = try!(ssl.into_ssl());
        let mut stream = try!(Self::new_base(ssl, stream));
        let ret = stream.ssl.accept();
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
        }
    }

    pub fn get_ref(&self) -> &S {
        unsafe {
            let bio = self.ssl.get_raw_rbio();
            bio::get_ref(bio)
        }
    }

    pub fn mut_ref(&mut self) -> &mut S {
        unsafe {
            let bio = self.ssl.get_raw_rbio();
            bio::get_mut(bio)
        }
    }

    fn make_error(&mut self, ret: c_int) -> SslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => SslError::get(),
            LibSslError::ErrorSyscall => {
                let err = SslError::get();
                let count = match err {
                    SslError::OpenSslErrors(ref v) => v.len(),
                    _ => unreachable!(),
                };
                if count == 0 {
                    if ret == 0 {
                        SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                             "unexpected EOF observed"))
                    } else {
                        let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) };
                        SslError::StreamError(error.unwrap())
                    }
                } else {
                    err
                }
            }
            LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => {
                let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) };
                SslError::StreamError(error.unwrap())
            }
            err => panic!("unexpected error {:?} with ret {}", err, ret),
        }
    }
}

pub trait IntoSsl {
    fn into_ssl(self) -> Result<Ssl, SslError>;
}
@@ -1756,3 +1660,133 @@ impl<S: Read+Write> NonblockingSslStream<S> {
        }
    }
}

pub struct SslStreamNg<S> {
    ssl: Ssl,
    _method: Box<ffi::BIO_METHOD>, // :(
    _p: PhantomData<S>,
}

impl<S> Drop for SslStreamNg<S> {
    fn drop(&mut self) {
        unsafe {
            let _ = bio::take_stream::<S>(self.ssl.get_raw_rbio());
        }
    }
}

impl<S: Read + Write> SslStreamNg<S> {
    fn new_base(ssl: Ssl, stream: S) -> Result<Self, SslError> {
        unsafe {
            let (bio, method) = try!(bio::new(stream));
            ffi::SSL_set_bio(ssl.ssl, bio, bio);

            Ok(SslStreamNg {
                ssl: ssl,
                _method: method,
                _p: PhantomData,
            })
        }
    }

    /// Creates an SSL/TLS client operating over the provided stream.
    pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> {
        let ssl = try!(ssl.into_ssl());
        let mut stream = try!(Self::new_base(ssl, stream));
        let ret = stream.ssl.connect();
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
        }
    }

    /// Creates an SSL/TLS server operating over the provided stream.
    pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, SslError> {
        let ssl = try!(ssl.into_ssl());
        let mut stream = try!(Self::new_base(ssl, stream));
        let ret = stream.ssl.accept();
        if ret > 0 {
            Ok(stream)
        } else {
            Err(stream.make_error(ret))
        }
    }
}

impl<S> SslStreamNg<S> {
    pub fn get_ref(&self) -> &S {
        unsafe {
            let bio = self.ssl.get_raw_rbio();
            bio::get_ref(bio)
        }
    }

    pub fn get_mut(&mut self) -> &mut S {
        unsafe {
            let bio = self.ssl.get_raw_rbio();
            bio::get_mut(bio)
        }
    }

    fn make_error(&mut self, ret: c_int) -> SslError {
        match self.ssl.get_error(ret) {
            LibSslError::ErrorSsl => SslError::get(),
            LibSslError::ErrorSyscall => {
                let err = SslError::get();
                let count = match err {
                    SslError::OpenSslErrors(ref v) => v.len(),
                    _ => unreachable!(),
                };
                if count == 0 {
                    if ret == 0 {
                        SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
                                                             "unexpected EOF observed"))
                    } else {
                        let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) };
                        SslError::StreamError(error.unwrap())
                    }
                } else {
                    err
                }
            }
            LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => {
                let error = unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) };
                SslError::StreamError(error.unwrap())
            }
            err => panic!("unexpected error {:?} with ret {}", err, ret),
        }
    }
}

impl<S: Read> Read for SslStreamNg<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let ret = self.ssl.read(buf);
        if ret >= 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }
}

impl<S: Write> Write for SslStreamNg<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let ret = self.ssl.write(buf);
        if ret > 0 {
            return Ok(ret as usize);
        }

        match self.make_error(ret) {
            SslError::StreamError(e) => Err(e),
            e => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        self.get_mut().flush()
    }
}
+12 −0
Original line number Diff line number Diff line
@@ -936,3 +936,15 @@ fn ng_connect() {
    let ctx = SslContext::new(Sslv23).unwrap();
    SslStreamNg::connect(&ctx, stream).unwrap();
}

#[test]
fn ng_get() {
    let (_s, stream) = Server::new();
    let ctx = SslContext::new(Sslv23).unwrap();
    let mut stream = SslStreamNg::connect(&ctx, stream).unwrap();
    stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
    let mut resp = String::new();
    stream.read_to_string(&mut resp).unwrap();
    assert!(resp.starts_with("HTTP/1.0 200"));
    assert!(resp.ends_with("</HTML>\r\n\r\n"));
}