Commit 302590c2 authored by Steven Fackler's avatar Steven Fackler
Browse files

Major rewrite for better error handling

parent a42d5261
Loading
Loading
Loading
Loading

error.rs

0 → 100644
+18 −0
Original line number Diff line number Diff line
use std::libc::c_ulong;

use super::ffi;

pub enum SslError {
    StreamEof,
    SslSessionClosed,
    UnknownError(c_ulong)
}

impl SslError {
    pub fn get() -> Option<SslError> {
        match unsafe { ffi::ERR_get_error() } {
            0 => None,
            err => Some(UnknownError(err))
        }
    }
}
+3 −0
Original line number Diff line number Diff line
@@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char,
externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL)
externfn!(fn SSL_free(ssl: *SSL))
externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO))
externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_set_connect_state(ssl: *SSL))
externfn!(fn SSL_connect(ssl: *SSL) -> c_int)
externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int)
@@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int)

externfn!(fn BIO_s_mem() -> *BIO_METHOD)
externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO)
externfn!(fn BIO_free_all(a: *BIO))
externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int)
externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)
+173 −160
Original line number Diff line number Diff line
use std::rt::io::{Reader, Writer, Stream, Decorator};
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::task;
use std::libc::{c_int, c_void};
use std::ptr;
use std::task;
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::rt::io::{Stream, Reader, Writer, Decorator};
use std::vec;
use std::libc::{c_int, c_void};

mod ffi;
use error::{SslError, SslSessionClosed, StreamEof};

pub mod error;

#[cfg(test)]
mod tests;

mod ffi;

static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL;

@@ -35,7 +39,7 @@ pub enum SslMethod {
}

impl SslMethod {
    unsafe fn to_fn(&self) -> *ffi::SSL_METHOD {
    unsafe fn to_raw(&self) -> *ffi::SSL_METHOD {
        match *self {
            Sslv2 => ffi::SSLv2_method(),
            Sslv3 => ffi::SSLv3_method(),
@@ -45,56 +49,60 @@ impl SslMethod {
    }
}

pub struct SslCtx {
pub enum SslVerifyMode {
    SslVerifyPeer = ffi::SSL_VERIFY_PEER,
    SslVerifyNone = ffi::SSL_VERIFY_NONE
}

pub struct SslContext {
    priv ctx: *ffi::SSL_CTX
}

impl Drop for SslCtx {
impl Drop for SslContext {
    fn drop(&mut self) {
        unsafe { ffi::SSL_CTX_free(self.ctx); }
        unsafe { ffi::SSL_CTX_free(self.ctx) }
    }
}

impl SslCtx {
    pub fn new(method: SslMethod) -> SslCtx {
impl SslContext {
    pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> {
        init();

        let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) };
        assert!(ctx != ptr::null());

        SslCtx {
            ctx: ctx
        }
        let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) };
        if ctx == ptr::null() {
            return Err(SslError::get().unwrap());
        }

    pub fn set_verify(&mut self, mode: SslVerifyMode) {
        unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) }
        Ok(SslContext { ctx: ctx })
    }

    pub fn set_verify_locations(&mut self, CAfile: &str) {
        do CAfile.with_c_str |CAfile| {
            unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile,
                                                        ptr::null()); }
    pub fn new(method: SslMethod) -> SslContext {
        match SslContext::try_new(method) {
            Ok(ctx) => ctx,
            Err(err) => fail!("Error creating SSL context: {:?}", err)
        }
    }

    // TODO: support callback (see SSL_CTX_set_ex_data)
    pub fn set_verify(&mut self, mode: SslVerifyMode) {
        unsafe {
            ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None);
        }
    }

pub enum SslVerifyMode {
    SslVerifyNone = ffi::SSL_VERIFY_NONE,
    SslVerifyPeer = ffi::SSL_VERIFY_PEER
    pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> {
        let ret = do file.with_c_str |file| {
            unsafe {
                ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null())
            }
        };

#[deriving(Eq, FromPrimitive)]
enum SslError {
    ErrorNone = ffi::SSL_ERROR_NONE,
    ErrorSsl = ffi::SSL_ERROR_SSL,
    ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
    ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE,
    ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP,
    ErrorSyscall = ffi::SSL_ERROR_SYSCALL,
    ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN,
    ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT,
    ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
        if ret == 0 {
            Some(SslError::get().unwrap())
        } else {
            None
        }
    }
}

struct Ssl {
@@ -103,138 +111,155 @@ struct Ssl {

impl Drop for Ssl {
    fn drop(&mut self) {
        unsafe { ffi::SSL_free(self.ssl); }
        unsafe { ffi::SSL_free(self.ssl) }
    }
}

impl Ssl {
    fn new(ctx: &SslCtx) -> Ssl {
    fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> {
        let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
        assert!(ssl != ptr::null());

        Ssl { ssl: ssl }
        if ssl == ptr::null() {
            return Err(SslError::get().unwrap());
        }
        let ssl = Ssl { ssl: ssl };

    fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) {
        unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); }
        let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
        if rbio == ptr::null() {
            return Err(SslError::get().unwrap());
        }

    fn set_connect_state(&self) {
        unsafe { ffi::SSL_set_connect_state(self.ssl); }
        let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
        if wbio == ptr::null() {
            unsafe { ffi::BIO_free_all(rbio) }
            return Err(SslError::get().unwrap());
        }

    fn connect(&self) -> int {
        unsafe { ffi::SSL_connect(self.ssl) as int }
        unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) }
        Ok(ssl)
    }

    fn get_error(&self, ret: int) -> SslError {
        let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) };
        match FromPrimitive::from_int(err as int) {
            Some(err) => err,
            None => fail2!("Unknown error {}", err)
        }
    }
    fn get_rbio<'a>(&'a self) -> MemBio<'a> {
        let bio = unsafe { ffi::SSL_get_rbio(self.ssl) };
        assert!(bio != ptr::null());

    fn read(&self, buf: &[u8]) -> int {
        unsafe {
            ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
                          buf.len() as c_int) as int
        MemBio {
            ssl: self,
            bio: bio
        }
    }

    fn write(&self, buf: &[u8]) -> int {
        unsafe {
            ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
                           buf.len() as c_int) as int
        }
    }
    fn get_wbio<'a>(&'a self) -> MemBio<'a> {
        let bio = unsafe { ffi::SSL_get_wbio(self.ssl) };
        assert!(bio != ptr::null());

    fn shutdown(&self) -> int {
        unsafe { ffi::SSL_shutdown(self.ssl) as int }
        MemBio {
            ssl: self,
            bio: bio
        }
    }

// BIOs are freed by SSL_free
struct MemBio {
    bio: *ffi::BIO
    fn connect(&self) -> c_int {
        unsafe { ffi::SSL_connect(self.ssl) }
    }

impl MemBio {
    fn new() -> MemBio {
        let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
        assert!(bio != ptr::null());
    fn read(&self, buf: &mut [u8]) -> c_int {
        unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
                               buf.len() as c_int) }
    }

        MemBio { bio: bio }
    fn write(&self, buf: &[u8]) -> c_int {
        unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
                                buf.len() as c_int) }
    }

    fn write(&self, buf: &[u8]) {
        unsafe {
            let ret = ffi::BIO_write(self.bio,
                                     vec::raw::to_ptr(buf) as *c_void,
                                     buf.len() as c_int);
            if ret < 0 {
                fail2!("write returned {}", ret);
    fn get_error(&self, ret: c_int) -> LibSslError {
        let err = unsafe { ffi::SSL_get_error(self.ssl, ret) };
        match FromPrimitive::from_int(err as int) {
            Some(err) => err,
            None => unreachable!()
        }
    }
}

    fn read(&self, buf: &[u8]) -> uint {
        unsafe {
            let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
                                    buf.len() as c_int);
#[deriving(FromPrimitive)]
enum LibSslError {
    ErrorNone = ffi::SSL_ERROR_NONE,
    ErrorSsl = ffi::SSL_ERROR_SSL,
    ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
    ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE,
    ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP,
    ErrorSyscall = ffi::SSL_ERROR_SYSCALL,
    ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN,
    ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT,
    ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
}

struct MemBio<'self> {
    ssl: &'self Ssl,
    bio: *ffi::BIO
}

impl<'self> MemBio<'self> {
    fn read(&self, buf: &mut [u8]) -> Option<uint> {
        let ret = unsafe {
            ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
                          buf.len() as c_int)
        };

        if ret < 0 {
                0
            None
        } else {
                ret as uint
            Some(ret as uint)
        }
    }

    fn write(&self, buf: &[u8]) {
        let ret = unsafe {
            ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void,
                           buf.len() as c_int)
        };
        assert_eq!(buf.len(), ret as uint);
    }
}

pub struct SslStream<S> {
    priv ctx: SslCtx,
    priv stream: S,
    priv ssl: Ssl,
    priv buf: ~[u8],
    priv rbio: MemBio,
    priv wbio: MemBio,
    priv stream: S
    priv buf: ~[u8]
}

impl<S: Stream> SslStream<S> {
    pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> {
        let ssl = Ssl::new(&ctx);

        let rbio = MemBio::new();
        let wbio = MemBio::new();

        ssl.set_bio(&rbio, &wbio);
        ssl.set_connect_state();
    pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>,
                                                          SslError> {
        let ssl = match Ssl::try_new(ctx) {
            Ok(ssl) => ssl,
            Err(err) => return Err(err)
        };

        let mut stream = SslStream {
            ctx: ctx,
        let mut ssl = SslStream {
            stream: stream,
            ssl: ssl,
            // Max record size for SSLv3/TLSv1 is 16k
            buf: vec::from_elem(16 * 1024, 0u8),
            rbio: rbio,
            wbio: wbio,
            stream: stream
            // Maximum TLS record size is 16k
            buf: vec::from_elem(16 * 1024, 0u8)
        };

        let ret = do stream.in_retry_wrapper |ssl| {
            ssl.ssl.connect()
        };
        match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) {
            Ok(_) => Ok(ssl),
            Err(err) => Err(err)
        }
    }

        match ret {
            Ok(_) => Ok(stream),
            // FIXME
            Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint })
    pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> {
        match SslStream::try_new(ctx, stream) {
            Ok(stream) => stream,
            Err(err) => fail!("Error creating SSL stream: {:?}", err)
        }
    }

    fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int)
                        -> Result<int, SslError> {
    fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int)
            -> Result<c_int, SslError> {
        loop {
            let ret = blk(self);
            let ret = blk(&self.ssl);
            if ret > 0 {
                return Ok(ret);
            }
@@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> {
                ErrorWantRead => {
                    self.flush();
                    match self.stream.read(self.buf) {
                        Some(len) => self.rbio.write(self.buf.slice_to(len)),
                        None => return Err(ErrorZeroReturn) // FIXME
                        Some(len) =>
                            self.ssl.get_rbio().write(self.buf.slice_to(len)),
                        None => return Err(StreamEof)
                    }
                }
                ErrorWantWrite => self.flush(),
                err => return Err(err)
                ErrorZeroReturn => return Err(SslSessionClosed),
                ErrorSsl => return Err(SslError::get().unwrap()),
                _ => unreachable!()
            }
        }
    }

    fn write_through(&mut self) {
        loop {
            let len = self.wbio.read(self.buf);
            if len == 0 {
                return;
            }
            self.stream.write(self.buf.slice_to(len));
        }
    }

    pub fn shutdown(&mut self) {
        loop {
            let ret = do self.in_retry_wrapper |ssl| {
                ssl.ssl.shutdown()
            };

            if ret != Ok(0) {
                break;
            match self.ssl.get_wbio().read(self.buf) {
                Some(len) => self.stream.write(self.buf.slice_to(len)),
                None => break
            }
        }
    }
@@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> {

impl<S: Stream> Reader for SslStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> Option<uint> {
        let ret = do self.in_retry_wrapper |ssl| {
            ssl.ssl.read(buf)
        };

        match ret {
            Ok(num) => Some(num as uint),
            Err(_) => None
        match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
            Ok(len) => Some(len as uint),
            Err(StreamEof) | Err(SslSessionClosed) => None,
            _ => unreachable!()
        }
    }

@@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> {

impl<S: Stream> Writer for SslStream<S> {
    fn write(&mut self, buf: &[u8]) {
        let mut start = 0;
        while start < buf.len() {
            let ret = do self.in_retry_wrapper |ssl| {
            ssl.ssl.write(buf)
                ssl.write(buf.slice_from(start))
            };

            match ret {
            Ok(_) => (),
            Err(err) => fail2!("Write error: {:?}", err)
                Ok(len) => start += len as uint,
                _ => unreachable!()
            }

            self.write_through();
        }
    }

    fn flush(&mut self) {
        self.write_through();
        self.stream.flush();
        self.stream.flush()
    }
}

impl<S: Stream> Decorator<S> for SslStream<S> {
impl<S> Decorator<S> for SslStream<S> {
    fn inner(self) -> S {
        self.stream
    }
+11 −12
Original line number Diff line number Diff line
@@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil;
use std::rt::io::net::tcp::TcpStream;
use std::str;

use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer};
use super::{Sslv23, SslContext, SslStream, SslVerifyPeer};

#[test]
fn test_new_ctx() {
    SslCtx::new(Sslv23);
    SslContext::new(Sslv23);
}

#[test]
fn test_new_sslstream() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
    SslStream::new(&SslContext::new(Sslv23), stream);
}

#[test]
fn test_verify_untrusted() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut ctx = SslCtx::new(Sslv23);
    let mut ctx = SslContext::new(Sslv23);
    ctx.set_verify(SslVerifyPeer);
    match SslStream::new(ctx, stream) {
    match SslStream::try_new(&ctx, stream) {
        Ok(_) => fail2!("expected failure"),
        Err(err) => println!("error {}", err)
        Err(err) => println!("error {:?}", err)
    }
}

#[test]
fn test_verify_trusted() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut ctx = SslCtx::new(Sslv23);
    let mut ctx = SslContext::new(Sslv23);
    ctx.set_verify(SslVerifyPeer);
    ctx.set_verify_locations("cert.pem");
    match SslStream::new(ctx, stream) {
    assert!(ctx.set_CA_file("cert.pem").is_none());
    match SslStream::try_new(&ctx, stream) {
        Ok(_) => (),
        Err(err) => fail2!("Expected success, got {:?}", err)
    }
@@ -42,18 +42,17 @@ fn test_verify_trusted() {
#[test]
fn test_write() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
    let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
    stream.write("hello".as_bytes());
    stream.flush();
    stream.write(" there".as_bytes());
    stream.flush();
    stream.shutdown();
}

#[test]
fn test_read() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
    let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
    let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
    stream.write("GET /\r\n\r\n".as_bytes());
    stream.flush();
    let buf = stream.read_to_end();