Commit d04c4ef4 authored by Steven Fackler's avatar Steven Fackler
Browse files

Connect working

parent d4d5547a
Loading
Loading
Loading
Loading
+16 −2
Original line number Diff line number Diff line
#[doc(hidden)];

use std::libc::{c_int, c_void};
use std::libc::{c_int, c_long, c_void};

// openssl/ssl.h
pub type SSL_CTX = c_void;
pub type SSL_METHOD = c_void;
pub type SSL = c_void;
pub type BIO = c_void;
pub type BIO_METHOD = c_void;

pub static SSL_ERROR_NONE: c_int = 0;
pub static SSL_ERROR_SSL: c_int = 1;
pub static SSL_ERROR_WANT_READ: c_int = 2;
pub static SSL_ERROR_WANT_WRITE: c_int = 3;
pub static SSL_ERROR_WANT_X509_LOOKUP: c_int = 4;
pub static SSL_ERROR_SYSCALL: c_int = 5;
pub static SSL_ERROR_ZERO_RETURN: c_int = 6;
pub static SSL_ERROR_WANT_CONNECT: c_int = 7;
pub static SSL_ERROR_WANT_ACCEPT: c_int = 8;

#[link_args = "-lssl"]
extern "C" { }

@@ -22,8 +33,11 @@ externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL)
externfn!(fn SSL_free(ssl: *SSL))
externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO))
externfn!(fn SSL_set_connect_state(ssl: *SSL))
externfn!(fn SSL_do_handshake(ssl: *SSL))
externfn!(fn SSL_connect(ssl: *SSL) -> c_int)
externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int)

externfn!(fn BIO_s_mem() -> *BIO_METHOD)
externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO)
externfn!(fn BIO_free(a: *BIO) -> c_int)
externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int)
externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)
+92 −2
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@ use std::rt::io::{Stream, Decorator};
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::task;
use std::ptr;
use std::vec;
use std::libc::{c_int, c_void};

mod ffi;

@@ -68,6 +70,17 @@ impl Drop for Ssl {
    }
}

enum SslError {
    ErrorNone,
    ErrorSsl,
    ErrorWantRead,
    ErrorWantWrite,
    ErrorWantX509Lookup,
    ErrorZeroReturn,
    ErrorWantConnect,
    ErrorWantAccept,
}

impl Ssl {
    fn new(ctx: &SslCtx) -> Ssl {
        let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
@@ -83,6 +96,24 @@ impl Ssl {
    fn set_connect_state(&self) {
        unsafe { ffi::SSL_set_connect_state(self.ssl); }
    }

    fn connect(&self) -> int {
        unsafe { ffi::SSL_connect(self.ssl) as int }
    }

    fn get_error(&self, ret: int) -> SslError {
        match unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) } {
            ffi::SSL_ERROR_NONE => ErrorNone,
            ffi::SSL_ERROR_SSL => ErrorSsl,
            ffi::SSL_ERROR_WANT_READ => ErrorWantRead,
            ffi::SSL_ERROR_WANT_WRITE => ErrorWantWrite,
            ffi::SSL_ERROR_WANT_X509_LOOKUP => ErrorWantX509Lookup,
            ffi::SSL_ERROR_ZERO_RETURN => ErrorZeroReturn,
            ffi::SSL_ERROR_WANT_CONNECT => ErrorWantConnect,
            ffi::SSL_ERROR_WANT_ACCEPT => ErrorWantAccept,
            _ => unreachable!()
        }
    }
}

struct MemBio {
@@ -102,11 +133,34 @@ impl MemBio {

        MemBio { bio: bio }
    }

    fn write(&self, buf: &[u8]) {
        unsafe {
            let ret = ffi::BIO_write(self.bio,
                                     vec::raw::to_ptr(buf) as *c_void,
                                     buf.len() as c_int);
            if ret < 0 {
                fail2!("write returned {}", ret);
            }
        }
    }

    fn read(&self, buf: &[u8]) -> uint {
        unsafe {
            let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
                                    buf.len() as c_int);
            if ret < 0 {
                fail2!("read returned {}", ret);
            }
            ret as uint
        }
    }
}

pub struct SslStream<S> {
    priv ctx: SslCtx,
    priv ssl: Ssl,
    priv buf: ~[u8],
    priv rbio: MemBio,
    priv wbio: MemBio,
    priv stream: S
@@ -122,16 +176,52 @@ impl<S: Stream> SslStream<S> {
        ssl.set_bio(&rbio, &wbio);
        ssl.set_connect_state();

        let stream = SslStream {
        let mut stream = SslStream {
            ctx: ctx,
            ssl: ssl,
            buf: vec::from_elem(16 * 1024, 0u8),
            rbio: rbio,
            wbio: wbio,
            stream: stream
        }
        };

        stream.connect();

        stream
    }

    fn connect(&mut self) {
        info!("in connect");
        loop {
            let ret = self.ssl.connect();
            info2!("connect returned {}", ret);
            if ret == 1 {
                return;
            }

            match self.ssl.get_error(ret) {
                ErrorWantRead => {
                    info2!("want read");
                    self.flush();
                    match self.stream.read(self.buf) {
                        Some(len) => self.rbio.write(self.buf.slice_to(len)),
                        None => unreachable!()
                    }
                }
                ErrorWantWrite => {
                    info2!("want write");
                    self.flush();
                }
                _ => unreachable!()
            }
        }
    }

    fn flush(&mut self) {
        let len = self.wbio.read(self.buf);
        self.stream.write(self.buf.slice_to(len));
        self.stream.flush();
    }
}

impl<S: Stream> Decorator<S> for SslStream<S> {
+1 −1
Original line number Diff line number Diff line
@@ -12,5 +12,5 @@ fn test_new_ctx() {
#[test]
fn test_new_sslstream() {
    let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap());
    let stream = SslStream::new(SslCtx::new(Sslv23), stream);
    SslStream::new(SslCtx::new(Sslv23), stream);
}