Commit 49678805 authored by Steven Fackler's avatar Steven Fackler
Browse files

Properly propogate errors

parent f5f10dea
Loading
Loading
Loading
Loading
+18 −12
Original line number Diff line number Diff line
@@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT;
static mut VERIFY_IDX: c_int = -1;
static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex];

macro_rules! try_ssl(
    ($e:expr) => (
        match $e {
            Ok(ok) => ok,
            Err(err) => return Err(StreamError(err))
        }
    )
)

fn init() {
    unsafe {
        INIT.doit(|| {
@@ -480,14 +489,11 @@ impl<S: Stream> SslStream<S> {

            match self.ssl.get_error(ret) {
                ErrorWantRead => {
                    self.flush();
                    match self.stream.read(self.buf) {
                        Ok(len) =>
                            self.ssl.get_rbio().write(self.buf.slice_to(len)),
                        Err(err) => return Err(StreamError(err))
                    try_ssl!(self.flush());
                    let len = try_ssl!(self.stream.read(self.buf));
                    self.ssl.get_rbio().write(self.buf.slice_to(len));
                }
                }
                ErrorWantWrite => { self.flush(); }
                ErrorWantWrite => { try_ssl!(self.flush()) }
                ErrorZeroReturn => return Err(SslSessionClosed),
                ErrorSsl => return Err(SslError::get()),
                _ => unreachable!()
@@ -495,14 +501,14 @@ impl<S: Stream> SslStream<S> {
        }
    }

    fn write_through(&mut self) {
    fn write_through(&mut self) -> IoResult<()> {
        loop {
            // TODO propogate errors
            match self.ssl.get_wbio().read(self.buf) {
                Some(len) => self.stream.write(self.buf.slice_to(len)),
                Some(len) => try!(self.stream.write(self.buf.slice_to(len))),
                None => break
            };
        }
        Ok(())
    }
}

@@ -533,13 +539,13 @@ impl<S: Stream> Writer for SslStream<S> {
                Ok(len) => start += len as uint,
                _ => unreachable!()
            }
            self.write_through();
            try!(self.write_through());
        }
        Ok(())
    }

    fn flush(&mut self) -> IoResult<()> {
        self.write_through();
        try!(self.write_through());
        self.stream.flush()
    }
}
+6 −6
Original line number Diff line number Diff line
@@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() {
fn test_write() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
    stream.write("hello".as_bytes());
    stream.flush();
    stream.write(" there".as_bytes());
    stream.flush();
    stream.write("hello".as_bytes()).unwrap();
    stream.flush().unwrap();
    stream.write(" there".as_bytes()).unwrap();
    stream.flush().unwrap();
}

#[test]
fn test_read() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
    stream.write("GET /\r\n\r\n".as_bytes());
    stream.flush();
    stream.write("GET /\r\n\r\n".as_bytes()).unwrap();
    stream.flush().unwrap();
    let buf = stream.read_to_end().ok().expect("read error");
    print!("{}", str::from_utf8(buf));
}