Loading error.rs 0 → 100644 +18 −0 Original line number Diff line number Diff line use std::libc::c_ulong; use super::ffi; pub enum SslError { StreamEof, SslSessionClosed, UnknownError(c_ulong) } impl SslError { pub fn get() -> Option<SslError> { match unsafe { ffi::ERR_get_error() } { 0 => None, err => Some(UnknownError(err)) } } } ffi.rs +3 −0 Original line number Diff line number Diff line Loading @@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char, externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL) externfn!(fn SSL_free(ssl: *SSL)) externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO)) externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO) externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO) externfn!(fn SSL_set_connect_state(ssl: *SSL)) externfn!(fn SSL_connect(ssl: *SSL) -> c_int) externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int) Loading @@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int) externfn!(fn BIO_s_mem() -> *BIO_METHOD) externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO) externfn!(fn BIO_free_all(a: *BIO)) externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int) externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int) lib.rs +173 −160 Original line number Diff line number Diff line use std::rt::io::{Reader, Writer, Stream, Decorator}; use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; use std::task; use std::libc::{c_int, c_void}; use std::ptr; use std::task; use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; use std::rt::io::{Stream, Reader, Writer, Decorator}; use std::vec; use std::libc::{c_int, c_void}; mod ffi; use error::{SslError, SslSessionClosed, StreamEof}; pub mod error; #[cfg(test)] mod tests; mod ffi; static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL; static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL; Loading @@ -35,7 +39,7 @@ pub enum SslMethod { } impl SslMethod { unsafe fn to_fn(&self) -> *ffi::SSL_METHOD { unsafe fn to_raw(&self) -> *ffi::SSL_METHOD { match *self { Sslv2 => ffi::SSLv2_method(), Sslv3 => ffi::SSLv3_method(), Loading @@ -45,56 +49,60 @@ impl SslMethod { } } pub struct SslCtx { pub enum SslVerifyMode { SslVerifyPeer = ffi::SSL_VERIFY_PEER, SslVerifyNone = ffi::SSL_VERIFY_NONE } pub struct SslContext { priv ctx: *ffi::SSL_CTX } impl Drop for SslCtx { impl Drop for SslContext { fn drop(&mut self) { unsafe { ffi::SSL_CTX_free(self.ctx); } unsafe { ffi::SSL_CTX_free(self.ctx) } } } impl SslCtx { pub fn new(method: SslMethod) -> SslCtx { impl SslContext { pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> { init(); let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) }; assert!(ctx != ptr::null()); SslCtx { ctx: ctx } let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) }; if ctx == ptr::null() { return Err(SslError::get().unwrap()); } pub fn set_verify(&mut self, mode: SslVerifyMode) { unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) } Ok(SslContext { ctx: ctx }) } pub fn set_verify_locations(&mut self, CAfile: &str) { do CAfile.with_c_str |CAfile| { unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile, ptr::null()); } pub fn new(method: SslMethod) -> SslContext { match SslContext::try_new(method) { Ok(ctx) => ctx, Err(err) => fail!("Error creating SSL context: {:?}", err) } } // TODO: support callback (see SSL_CTX_set_ex_data) pub fn set_verify(&mut self, mode: SslVerifyMode) { unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None); } } pub enum SslVerifyMode { SslVerifyNone = ffi::SSL_VERIFY_NONE, SslVerifyPeer = ffi::SSL_VERIFY_PEER pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> { let ret = do file.with_c_str |file| { unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null()) } }; #[deriving(Eq, FromPrimitive)] enum SslError { ErrorNone = ffi::SSL_ERROR_NONE, ErrorSsl = ffi::SSL_ERROR_SSL, ErrorWantRead = ffi::SSL_ERROR_WANT_READ, ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE, ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP, ErrorSyscall = ffi::SSL_ERROR_SYSCALL, ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN, ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT, ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, if ret == 0 { Some(SslError::get().unwrap()) } else { None } } } struct Ssl { Loading @@ -103,138 +111,155 @@ struct Ssl { impl Drop for Ssl { fn drop(&mut self) { unsafe { ffi::SSL_free(self.ssl); } unsafe { ffi::SSL_free(self.ssl) } } } impl Ssl { fn new(ctx: &SslCtx) -> Ssl { fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> { let ssl = unsafe { ffi::SSL_new(ctx.ctx) }; assert!(ssl != ptr::null()); Ssl { ssl: ssl } if ssl == ptr::null() { return Err(SslError::get().unwrap()); } let ssl = Ssl { ssl: ssl }; fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) { unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); } let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; if rbio == ptr::null() { return Err(SslError::get().unwrap()); } fn set_connect_state(&self) { unsafe { ffi::SSL_set_connect_state(self.ssl); } let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; if wbio == ptr::null() { unsafe { ffi::BIO_free_all(rbio) } return Err(SslError::get().unwrap()); } fn connect(&self) -> int { unsafe { ffi::SSL_connect(self.ssl) as int } unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) } Ok(ssl) } fn get_error(&self, ret: int) -> SslError { let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) }; match FromPrimitive::from_int(err as int) { Some(err) => err, None => fail2!("Unknown error {}", err) } } fn get_rbio<'a>(&'a self) -> MemBio<'a> { let bio = unsafe { ffi::SSL_get_rbio(self.ssl) }; assert!(bio != ptr::null()); fn read(&self, buf: &[u8]) -> int { unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) as int MemBio { ssl: self, bio: bio } } fn write(&self, buf: &[u8]) -> int { unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) as int } } fn get_wbio<'a>(&'a self) -> MemBio<'a> { let bio = unsafe { ffi::SSL_get_wbio(self.ssl) }; assert!(bio != ptr::null()); fn shutdown(&self) -> int { unsafe { ffi::SSL_shutdown(self.ssl) as int } MemBio { ssl: self, bio: bio } } // BIOs are freed by SSL_free struct MemBio { bio: *ffi::BIO fn connect(&self) -> c_int { unsafe { ffi::SSL_connect(self.ssl) } } impl MemBio { fn new() -> MemBio { let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; assert!(bio != ptr::null()); fn read(&self, buf: &mut [u8]) -> c_int { unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) } } MemBio { bio: bio } fn write(&self, buf: &[u8]) -> c_int { unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) } } fn write(&self, buf: &[u8]) { unsafe { let ret = ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int); if ret < 0 { fail2!("write returned {}", ret); fn get_error(&self, ret: c_int) -> LibSslError { let err = unsafe { ffi::SSL_get_error(self.ssl, ret) }; match FromPrimitive::from_int(err as int) { Some(err) => err, None => unreachable!() } } } fn read(&self, buf: &[u8]) -> uint { unsafe { let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int); #[deriving(FromPrimitive)] enum LibSslError { ErrorNone = ffi::SSL_ERROR_NONE, ErrorSsl = ffi::SSL_ERROR_SSL, ErrorWantRead = ffi::SSL_ERROR_WANT_READ, ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE, ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP, ErrorSyscall = ffi::SSL_ERROR_SYSCALL, ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN, ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT, ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, } struct MemBio<'self> { ssl: &'self Ssl, bio: *ffi::BIO } impl<'self> MemBio<'self> { fn read(&self, buf: &mut [u8]) -> Option<uint> { let ret = unsafe { ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) }; if ret < 0 { 0 None } else { ret as uint Some(ret as uint) } } fn write(&self, buf: &[u8]) { let ret = unsafe { ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) }; assert_eq!(buf.len(), ret as uint); } } pub struct SslStream<S> { priv ctx: SslCtx, priv stream: S, priv ssl: Ssl, priv buf: ~[u8], priv rbio: MemBio, priv wbio: MemBio, priv stream: S priv buf: ~[u8] } impl<S: Stream> SslStream<S> { pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> { let ssl = Ssl::new(&ctx); let rbio = MemBio::new(); let wbio = MemBio::new(); ssl.set_bio(&rbio, &wbio); ssl.set_connect_state(); pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> { let ssl = match Ssl::try_new(ctx) { Ok(ssl) => ssl, Err(err) => return Err(err) }; let mut stream = SslStream { ctx: ctx, let mut ssl = SslStream { stream: stream, ssl: ssl, // Max record size for SSLv3/TLSv1 is 16k buf: vec::from_elem(16 * 1024, 0u8), rbio: rbio, wbio: wbio, stream: stream // Maximum TLS record size is 16k buf: vec::from_elem(16 * 1024, 0u8) }; let ret = do stream.in_retry_wrapper |ssl| { ssl.ssl.connect() }; match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) { Ok(_) => Ok(ssl), Err(err) => Err(err) } } match ret { Ok(_) => Ok(stream), // FIXME Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint }) pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> { match SslStream::try_new(ctx, stream) { Ok(stream) => stream, Err(err) => fail!("Error creating SSL stream: {:?}", err) } } fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int) -> Result<int, SslError> { fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int) -> Result<c_int, SslError> { loop { let ret = blk(self); let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } Loading @@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> { ErrorWantRead => { self.flush(); match self.stream.read(self.buf) { Some(len) => self.rbio.write(self.buf.slice_to(len)), None => return Err(ErrorZeroReturn) // FIXME Some(len) => self.ssl.get_rbio().write(self.buf.slice_to(len)), None => return Err(StreamEof) } } ErrorWantWrite => self.flush(), err => return Err(err) ErrorZeroReturn => return Err(SslSessionClosed), ErrorSsl => return Err(SslError::get().unwrap()), _ => unreachable!() } } } fn write_through(&mut self) { loop { let len = self.wbio.read(self.buf); if len == 0 { return; } self.stream.write(self.buf.slice_to(len)); } } pub fn shutdown(&mut self) { loop { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.shutdown() }; if ret != Ok(0) { break; match self.ssl.get_wbio().read(self.buf) { Some(len) => self.stream.write(self.buf.slice_to(len)), None => break } } } Loading @@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> { impl<S: Stream> Reader for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> Option<uint> { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.read(buf) }; match ret { Ok(num) => Some(num as uint), Err(_) => None match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Some(len as uint), Err(StreamEof) | Err(SslSessionClosed) => None, _ => unreachable!() } } Loading @@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> { impl<S: Stream> Writer for SslStream<S> { fn write(&mut self, buf: &[u8]) { let mut start = 0; while start < buf.len() { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.write(buf) ssl.write(buf.slice_from(start)) }; match ret { Ok(_) => (), Err(err) => fail2!("Write error: {:?}", err) Ok(len) => start += len as uint, _ => unreachable!() } self.write_through(); } } fn flush(&mut self) { self.write_through(); self.stream.flush(); self.stream.flush() } } impl<S: Stream> Decorator<S> for SslStream<S> { impl<S> Decorator<S> for SslStream<S> { fn inner(self) -> S { self.stream } Loading tests.rs +11 −12 Original line number Diff line number Diff line Loading @@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil; use std::rt::io::net::tcp::TcpStream; use std::str; use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer}; use super::{Sslv23, SslContext, SslStream, SslVerifyPeer}; #[test] fn test_new_ctx() { SslCtx::new(Sslv23); SslContext::new(Sslv23); } #[test] fn test_new_sslstream() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); SslStream::new(&SslContext::new(Sslv23), stream); } #[test] fn test_verify_untrusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); match SslStream::new(ctx, stream) { match SslStream::try_new(&ctx, stream) { Ok(_) => fail2!("expected failure"), Err(err) => println!("error {}", err) Err(err) => println!("error {:?}", err) } } #[test] fn test_verify_trusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); ctx.set_verify_locations("cert.pem"); match SslStream::new(ctx, stream) { assert!(ctx.set_CA_file("cert.pem").is_none()); match SslStream::try_new(&ctx, stream) { Ok(_) => (), Err(err) => fail2!("Expected success, got {:?}", err) } Loading @@ -42,18 +42,17 @@ fn test_verify_trusted() { #[test] fn test_write() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("hello".as_bytes()); stream.flush(); stream.write(" there".as_bytes()); stream.flush(); stream.shutdown(); } #[test] fn test_read() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("GET /\r\n\r\n".as_bytes()); stream.flush(); let buf = stream.read_to_end(); Loading Loading
error.rs 0 → 100644 +18 −0 Original line number Diff line number Diff line use std::libc::c_ulong; use super::ffi; pub enum SslError { StreamEof, SslSessionClosed, UnknownError(c_ulong) } impl SslError { pub fn get() -> Option<SslError> { match unsafe { ffi::ERR_get_error() } { 0 => None, err => Some(UnknownError(err)) } } }
ffi.rs +3 −0 Original line number Diff line number Diff line Loading @@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char, externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL) externfn!(fn SSL_free(ssl: *SSL)) externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO)) externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO) externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO) externfn!(fn SSL_set_connect_state(ssl: *SSL)) externfn!(fn SSL_connect(ssl: *SSL) -> c_int) externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int) Loading @@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int) externfn!(fn BIO_s_mem() -> *BIO_METHOD) externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO) externfn!(fn BIO_free_all(a: *BIO)) externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int) externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)
lib.rs +173 −160 Original line number Diff line number Diff line use std::rt::io::{Reader, Writer, Stream, Decorator}; use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; use std::task; use std::libc::{c_int, c_void}; use std::ptr; use std::task; use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; use std::rt::io::{Stream, Reader, Writer, Decorator}; use std::vec; use std::libc::{c_int, c_void}; mod ffi; use error::{SslError, SslSessionClosed, StreamEof}; pub mod error; #[cfg(test)] mod tests; mod ffi; static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL; static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL; Loading @@ -35,7 +39,7 @@ pub enum SslMethod { } impl SslMethod { unsafe fn to_fn(&self) -> *ffi::SSL_METHOD { unsafe fn to_raw(&self) -> *ffi::SSL_METHOD { match *self { Sslv2 => ffi::SSLv2_method(), Sslv3 => ffi::SSLv3_method(), Loading @@ -45,56 +49,60 @@ impl SslMethod { } } pub struct SslCtx { pub enum SslVerifyMode { SslVerifyPeer = ffi::SSL_VERIFY_PEER, SslVerifyNone = ffi::SSL_VERIFY_NONE } pub struct SslContext { priv ctx: *ffi::SSL_CTX } impl Drop for SslCtx { impl Drop for SslContext { fn drop(&mut self) { unsafe { ffi::SSL_CTX_free(self.ctx); } unsafe { ffi::SSL_CTX_free(self.ctx) } } } impl SslCtx { pub fn new(method: SslMethod) -> SslCtx { impl SslContext { pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> { init(); let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) }; assert!(ctx != ptr::null()); SslCtx { ctx: ctx } let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) }; if ctx == ptr::null() { return Err(SslError::get().unwrap()); } pub fn set_verify(&mut self, mode: SslVerifyMode) { unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) } Ok(SslContext { ctx: ctx }) } pub fn set_verify_locations(&mut self, CAfile: &str) { do CAfile.with_c_str |CAfile| { unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile, ptr::null()); } pub fn new(method: SslMethod) -> SslContext { match SslContext::try_new(method) { Ok(ctx) => ctx, Err(err) => fail!("Error creating SSL context: {:?}", err) } } // TODO: support callback (see SSL_CTX_set_ex_data) pub fn set_verify(&mut self, mode: SslVerifyMode) { unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None); } } pub enum SslVerifyMode { SslVerifyNone = ffi::SSL_VERIFY_NONE, SslVerifyPeer = ffi::SSL_VERIFY_PEER pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> { let ret = do file.with_c_str |file| { unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null()) } }; #[deriving(Eq, FromPrimitive)] enum SslError { ErrorNone = ffi::SSL_ERROR_NONE, ErrorSsl = ffi::SSL_ERROR_SSL, ErrorWantRead = ffi::SSL_ERROR_WANT_READ, ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE, ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP, ErrorSyscall = ffi::SSL_ERROR_SYSCALL, ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN, ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT, ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, if ret == 0 { Some(SslError::get().unwrap()) } else { None } } } struct Ssl { Loading @@ -103,138 +111,155 @@ struct Ssl { impl Drop for Ssl { fn drop(&mut self) { unsafe { ffi::SSL_free(self.ssl); } unsafe { ffi::SSL_free(self.ssl) } } } impl Ssl { fn new(ctx: &SslCtx) -> Ssl { fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> { let ssl = unsafe { ffi::SSL_new(ctx.ctx) }; assert!(ssl != ptr::null()); Ssl { ssl: ssl } if ssl == ptr::null() { return Err(SslError::get().unwrap()); } let ssl = Ssl { ssl: ssl }; fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) { unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); } let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; if rbio == ptr::null() { return Err(SslError::get().unwrap()); } fn set_connect_state(&self) { unsafe { ffi::SSL_set_connect_state(self.ssl); } let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; if wbio == ptr::null() { unsafe { ffi::BIO_free_all(rbio) } return Err(SslError::get().unwrap()); } fn connect(&self) -> int { unsafe { ffi::SSL_connect(self.ssl) as int } unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) } Ok(ssl) } fn get_error(&self, ret: int) -> SslError { let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) }; match FromPrimitive::from_int(err as int) { Some(err) => err, None => fail2!("Unknown error {}", err) } } fn get_rbio<'a>(&'a self) -> MemBio<'a> { let bio = unsafe { ffi::SSL_get_rbio(self.ssl) }; assert!(bio != ptr::null()); fn read(&self, buf: &[u8]) -> int { unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) as int MemBio { ssl: self, bio: bio } } fn write(&self, buf: &[u8]) -> int { unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) as int } } fn get_wbio<'a>(&'a self) -> MemBio<'a> { let bio = unsafe { ffi::SSL_get_wbio(self.ssl) }; assert!(bio != ptr::null()); fn shutdown(&self) -> int { unsafe { ffi::SSL_shutdown(self.ssl) as int } MemBio { ssl: self, bio: bio } } // BIOs are freed by SSL_free struct MemBio { bio: *ffi::BIO fn connect(&self) -> c_int { unsafe { ffi::SSL_connect(self.ssl) } } impl MemBio { fn new() -> MemBio { let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; assert!(bio != ptr::null()); fn read(&self, buf: &mut [u8]) -> c_int { unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) } } MemBio { bio: bio } fn write(&self, buf: &[u8]) -> c_int { unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) } } fn write(&self, buf: &[u8]) { unsafe { let ret = ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int); if ret < 0 { fail2!("write returned {}", ret); fn get_error(&self, ret: c_int) -> LibSslError { let err = unsafe { ffi::SSL_get_error(self.ssl, ret) }; match FromPrimitive::from_int(err as int) { Some(err) => err, None => unreachable!() } } } fn read(&self, buf: &[u8]) -> uint { unsafe { let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int); #[deriving(FromPrimitive)] enum LibSslError { ErrorNone = ffi::SSL_ERROR_NONE, ErrorSsl = ffi::SSL_ERROR_SSL, ErrorWantRead = ffi::SSL_ERROR_WANT_READ, ErrorWantWrite = ffi::SSL_ERROR_WANT_WRITE, ErrorWantX509Lookup = ffi::SSL_ERROR_WANT_X509_LOOKUP, ErrorSyscall = ffi::SSL_ERROR_SYSCALL, ErrorZeroReturn = ffi::SSL_ERROR_ZERO_RETURN, ErrorWantConnect = ffi::SSL_ERROR_WANT_CONNECT, ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, } struct MemBio<'self> { ssl: &'self Ssl, bio: *ffi::BIO } impl<'self> MemBio<'self> { fn read(&self, buf: &mut [u8]) -> Option<uint> { let ret = unsafe { ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) }; if ret < 0 { 0 None } else { ret as uint Some(ret as uint) } } fn write(&self, buf: &[u8]) { let ret = unsafe { ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int) }; assert_eq!(buf.len(), ret as uint); } } pub struct SslStream<S> { priv ctx: SslCtx, priv stream: S, priv ssl: Ssl, priv buf: ~[u8], priv rbio: MemBio, priv wbio: MemBio, priv stream: S priv buf: ~[u8] } impl<S: Stream> SslStream<S> { pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> { let ssl = Ssl::new(&ctx); let rbio = MemBio::new(); let wbio = MemBio::new(); ssl.set_bio(&rbio, &wbio); ssl.set_connect_state(); pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> { let ssl = match Ssl::try_new(ctx) { Ok(ssl) => ssl, Err(err) => return Err(err) }; let mut stream = SslStream { ctx: ctx, let mut ssl = SslStream { stream: stream, ssl: ssl, // Max record size for SSLv3/TLSv1 is 16k buf: vec::from_elem(16 * 1024, 0u8), rbio: rbio, wbio: wbio, stream: stream // Maximum TLS record size is 16k buf: vec::from_elem(16 * 1024, 0u8) }; let ret = do stream.in_retry_wrapper |ssl| { ssl.ssl.connect() }; match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) { Ok(_) => Ok(ssl), Err(err) => Err(err) } } match ret { Ok(_) => Ok(stream), // FIXME Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint }) pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> { match SslStream::try_new(ctx, stream) { Ok(stream) => stream, Err(err) => fail!("Error creating SSL stream: {:?}", err) } } fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int) -> Result<int, SslError> { fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int) -> Result<c_int, SslError> { loop { let ret = blk(self); let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } Loading @@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> { ErrorWantRead => { self.flush(); match self.stream.read(self.buf) { Some(len) => self.rbio.write(self.buf.slice_to(len)), None => return Err(ErrorZeroReturn) // FIXME Some(len) => self.ssl.get_rbio().write(self.buf.slice_to(len)), None => return Err(StreamEof) } } ErrorWantWrite => self.flush(), err => return Err(err) ErrorZeroReturn => return Err(SslSessionClosed), ErrorSsl => return Err(SslError::get().unwrap()), _ => unreachable!() } } } fn write_through(&mut self) { loop { let len = self.wbio.read(self.buf); if len == 0 { return; } self.stream.write(self.buf.slice_to(len)); } } pub fn shutdown(&mut self) { loop { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.shutdown() }; if ret != Ok(0) { break; match self.ssl.get_wbio().read(self.buf) { Some(len) => self.stream.write(self.buf.slice_to(len)), None => break } } } Loading @@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> { impl<S: Stream> Reader for SslStream<S> { fn read(&mut self, buf: &mut [u8]) -> Option<uint> { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.read(buf) }; match ret { Ok(num) => Some(num as uint), Err(_) => None match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { Ok(len) => Some(len as uint), Err(StreamEof) | Err(SslSessionClosed) => None, _ => unreachable!() } } Loading @@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> { impl<S: Stream> Writer for SslStream<S> { fn write(&mut self, buf: &[u8]) { let mut start = 0; while start < buf.len() { let ret = do self.in_retry_wrapper |ssl| { ssl.ssl.write(buf) ssl.write(buf.slice_from(start)) }; match ret { Ok(_) => (), Err(err) => fail2!("Write error: {:?}", err) Ok(len) => start += len as uint, _ => unreachable!() } self.write_through(); } } fn flush(&mut self) { self.write_through(); self.stream.flush(); self.stream.flush() } } impl<S: Stream> Decorator<S> for SslStream<S> { impl<S> Decorator<S> for SslStream<S> { fn inner(self) -> S { self.stream } Loading
tests.rs +11 −12 Original line number Diff line number Diff line Loading @@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil; use std::rt::io::net::tcp::TcpStream; use std::str; use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer}; use super::{Sslv23, SslContext, SslStream, SslVerifyPeer}; #[test] fn test_new_ctx() { SslCtx::new(Sslv23); SslContext::new(Sslv23); } #[test] fn test_new_sslstream() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); SslStream::new(&SslContext::new(Sslv23), stream); } #[test] fn test_verify_untrusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); match SslStream::new(ctx, stream) { match SslStream::try_new(&ctx, stream) { Ok(_) => fail2!("expected failure"), Err(err) => println!("error {}", err) Err(err) => println!("error {:?}", err) } } #[test] fn test_verify_trusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); ctx.set_verify_locations("cert.pem"); match SslStream::new(ctx, stream) { assert!(ctx.set_CA_file("cert.pem").is_none()); match SslStream::try_new(&ctx, stream) { Ok(_) => (), Err(err) => fail2!("Expected success, got {:?}", err) } Loading @@ -42,18 +42,17 @@ fn test_verify_trusted() { #[test] fn test_write() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("hello".as_bytes()); stream.flush(); stream.write(" there".as_bytes()); stream.flush(); stream.shutdown(); } #[test] fn test_read() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("GET /\r\n\r\n".as_bytes()); stream.flush(); let buf = stream.read_to_end(); Loading