Unverified Commit 5716dbf2 authored by Steven Fackler's avatar Steven Fackler
Browse files

Start on PkeyCtx

parent 7d000285
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -149,6 +149,16 @@ extern "C" {
    #[cfg(ossl111)]
    pub fn EVP_DigestFinalXOF(ctx: *mut EVP_MD_CTX, res: *mut u8, len: usize) -> c_int;

    #[cfg(ossl300)]
    pub fn EVP_MD_fetch(
        ctx: *mut OSSL_LIB_CTX,
        algorithm: *const c_char,
        properties: *const c_char,
    ) -> *mut EVP_MD;

    #[cfg(ossl300)]
    pub fn EVP_MD_free(md: *mut EVP_MD);

    pub fn EVP_BytesToKey(
        typ: *const EVP_CIPHER,
        md: *const EVP_MD,
+5 −0
Original line number Diff line number Diff line
@@ -153,6 +153,7 @@ pub mod fips;
pub mod hash;
#[cfg(ossl300)]
pub mod lib_ctx;
pub mod md;
pub mod memcmp;
pub mod nid;
#[cfg(not(osslconf = "OPENSSL_NO_OCSP"))]
@@ -161,6 +162,7 @@ pub mod pkcs12;
pub mod pkcs5;
pub mod pkcs7;
pub mod pkey;
pub mod pkey_ctx;
pub mod rand;
pub mod rsa;
pub mod sha;
@@ -173,6 +175,7 @@ pub mod symm;
pub mod version;
pub mod x509;

#[inline]
fn cvt_p<T>(r: *mut T) -> Result<*mut T, ErrorStack> {
    if r.is_null() {
        Err(ErrorStack::get())
@@ -181,6 +184,7 @@ fn cvt_p<T>(r: *mut T) -> Result<*mut T, ErrorStack> {
    }
}

#[inline]
fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
    if r <= 0 {
        Err(ErrorStack::get())
@@ -189,6 +193,7 @@ fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
    }
}

#[inline]
fn cvt_n(r: c_int) -> Result<c_int, ErrorStack> {
    if r < 0 {
        Err(ErrorStack::get())

openssl/src/md.rs

0 → 100644
+234 −0
Original line number Diff line number Diff line
#[cfg(ossl300)]
use crate::cvt_p;
#[cfg(ossl300)]
use crate::error::ErrorStack;
#[cfg(ossl300)]
use crate::lib_ctx::LibCtxRef;
use crate::nid::Nid;
use cfg_if::cfg_if;
use foreign_types::{ForeignTypeRef, Opaque};
#[cfg(ossl300)]
use std::ffi::CString;
#[cfg(ossl300)]
use std::ptr;

cfg_if! {
    if #[cfg(ossl300)] {
        use foreign_types::ForeignType;
        use std::ops::{Deref, DerefMut};

        type Inner = *mut ffi::EVP_MD;

        impl Drop for Md {
            #[inline]
            fn drop(&mut self) {
                unsafe {
                    ffi::EVP_MD_free(self.as_ptr());
                }
            }
        }

        impl ForeignType for Md {
            type CType = ffi::EVP_MD;
            type Ref = MdRef;

            #[inline]
            unsafe fn from_ptr(ptr: *mut Self::CType) -> Self {
                Md(ptr)
            }

            #[inline]
            fn as_ptr(&self) -> *mut Self::CType {
                self.0
            }
        }

        impl Deref for Md {
            type Target = MdRef;

            #[inline]
            fn deref(&self) -> &Self::Target {
                unsafe {
                    MdRef::from_ptr(self.as_ptr())
                }
            }
        }

        impl DerefMut for Md {
            #[inline]
            fn deref_mut(&mut self) -> &mut Self::Target {
                unsafe {
                    MdRef::from_ptr_mut(self.as_ptr())
                }
            }
        }
    } else {
        enum Inner {}
    }
}

/// A message digest algorithm.
pub struct Md(Inner);

unsafe impl Sync for Md {}
unsafe impl Send for Md {}

impl Md {
    /// Returns the `Md` corresponding to an [`Nid`].
    ///
    /// This corresponds to [`EVP_get_digestbynid`].
    ///
    /// [`EVP_get_digestbynid`]: https://www.openssl.org/docs/manmaster/crypto/EVP_DigestInit.html
    pub fn from_nid(type_: Nid) -> Option<&'static MdRef> {
        unsafe {
            let ptr = ffi::EVP_get_digestbynid(type_.as_raw());
            if ptr.is_null() {
                None
            } else {
                Some(MdRef::from_ptr(ptr as *mut _))
            }
        }
    }

    /// Fetches an `Md` object corresponding to the specified algorithm name and properties.
    ///
    /// This corresponds to [`EVP_MD_fetch`].
    ///
    /// Requires OpenSSL 3.0.0 or newer.
    ///
    /// [`EVP_MD_fetch`]: https://www.openssl.org/docs/manmaster/man3/EVP_MD_fetch.html
    #[cfg(ossl300)]
    pub fn fetch(
        ctx: Option<&LibCtxRef>,
        algorithm: &str,
        properties: Option<&str>,
    ) -> Result<Self, ErrorStack> {
        let algorithm = CString::new(algorithm).unwrap();
        let properties = properties.map(|s| CString::new(s).unwrap());

        unsafe {
            let ptr = cvt_p(ffi::EVP_MD_fetch(
                ctx.map_or(ptr::null_mut(), ForeignTypeRef::as_ptr),
                algorithm.as_ptr(),
                properties.map_or(ptr::null_mut(), |s| s.as_ptr()),
            ))?;

            Ok(Md::from_ptr(ptr))
        }
    }

    #[inline]
    pub fn null() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_md_null() as *mut _) }
    }

    #[inline]
    pub fn md5() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_md5() as *mut _) }
    }

    #[inline]
    pub fn sha1() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha1() as *mut _) }
    }

    #[inline]
    pub fn sha224() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha224() as *mut _) }
    }

    #[inline]
    pub fn sha256() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha256() as *mut _) }
    }

    #[inline]
    pub fn sha384() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha384() as *mut _) }
    }

    #[inline]
    pub fn sha512() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha512() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn sha3_224() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha3_224() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn sha3_256() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha3_256() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn sha3_384() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha3_384() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn sha3_512() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sha3_512() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn shake128() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_shake128() as *mut _) }
    }

    #[cfg(ossl111)]
    #[inline]
    pub fn shake256() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_shake256() as *mut _) }
    }

    #[cfg(not(osslconf = "OPENSSL_NO_RMD160"))]
    #[inline]
    pub fn ripemd160() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_ripemd160() as *mut _) }
    }

    #[cfg(not(osslconf = "OPENSSL_NO_SM3"))]
    #[inline]
    pub fn sm3() -> &'static MdRef {
        unsafe { MdRef::from_ptr(ffi::EVP_sm3() as *mut _) }
    }
}

/// A reference to an [`Md`].
pub struct MdRef(Opaque);

impl ForeignTypeRef for MdRef {
    type CType = ffi::EVP_MD;
}

unsafe impl Sync for MdRef {}
unsafe impl Send for MdRef {}

impl MdRef {
    /// Returns the size of the digest in bytes.
    ///
    /// This corresponds to [`EVP_MD_size`].
    ///
    /// [`EVP_MD_size`]: https://www.openssl.org/docs/manmaster/man3/EVP_MD_size.html
    #[inline]
    pub fn size(&self) -> usize {
        unsafe { ffi::EVP_MD_size(self.as_ptr()) as usize }
    }

    /// Returns the [`Nid`] of the digest.
    ///
    /// This corresponds to [`EVP_MD_type`].
    ///
    /// [`EVP_MD_type`]: https://www.openssl.org/docs/manmaster/man3/EVP_MD_type.html
    #[inline]
    pub fn type_(&self) -> Nid {
        unsafe { Nid::from_raw(ffi::EVP_MD_type(self.as_ptr())) }
    }
}
+328 −0
Original line number Diff line number Diff line
//! The asymmetric encryption context.
//!
//! # Examples
//!
//! Encrypt data with RSA
//!
//! ```
//! use openssl::rsa::Rsa;
//! use openssl::pkey::PKey;
//! use openssl::pkey_ctx::PkeyCtx;
//!
//! let key = Rsa::generate(4096).unwrap();
//! let key = PKey::from_rsa(key).unwrap();
//!
//! let mut ctx = PkeyCtx::new(&key).unwrap();
//! ctx.encrypt_init().unwrap();
//!
//! let data = b"Some Crypto Text";
//! let mut ciphertext = vec![];
//! ctx.encrypt_to_vec(data, &mut ciphertext).unwrap();
//! ```
use crate::error::ErrorStack;
use crate::md::MdRef;
use crate::pkey::PKeyRef;
use crate::rsa::Padding;
#[cfg(any(ossl102, libressl310))]
use crate::util;
use crate::{cvt, cvt_p};
use foreign_types::{ForeignType, ForeignTypeRef};
#[cfg(any(ossl102, libressl310))]
use libc::c_int;
#[cfg(any(ossl102, libressl310))]
use std::convert::TryFrom;
use std::ptr;

foreign_type_and_impl_send_sync! {
    type CType = ffi::EVP_PKEY_CTX;
    fn drop = ffi::EVP_PKEY_CTX_free;

    pub struct PkeyCtx;
    /// A reference to a [`PkeyCtx`].
    pub struct PkeyCtxRef;
}

impl PkeyCtx {
    #[inline]
    pub fn new<T>(pkey: &PKeyRef<T>) -> Result<Self, ErrorStack> {
        unsafe {
            let ptr = cvt_p(ffi::EVP_PKEY_CTX_new(pkey.as_ptr(), ptr::null_mut()))?;
            Ok(PkeyCtx::from_ptr(ptr))
        }
    }
}

impl PkeyCtxRef {
    /// Returns the RSA padding mode in use.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_get_rsa_padding`].
    ///
    /// [`EVP_PKEY_CTX_get_rsa_padding`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_get_rsa_padding.html
    #[inline]
    pub fn rsa_padding(&self) -> Result<Padding, ErrorStack> {
        let mut pad = 0;
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_get_rsa_padding(self.as_ptr(), &mut pad))?;
        }

        Ok(Padding::from_raw(pad))
    }

    /// Sets the RSA padding mode.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set_rsa_padding`].
    ///
    /// [`EVP_PKEY_CTX_set_rsa_padding`]: https://www.openssl.org/docs/manmaster/crypto/EVP_PKEY_CTX_set_rsa_padding.html
    #[inline]
    pub fn set_rsa_padding(&mut self, padding: Padding) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set_rsa_padding(
                self.as_ptr(),
                padding.as_raw(),
            ))?;
        }

        Ok(())
    }

    /// Sets the RSA MGF1 algorithm.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set_rsa_mgf1_md`].
    ///
    /// [`EVP_PKEY_CTX_set_rsa_mgf1_md`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_set_rsa_mgf1_md.html
    #[inline]
    pub fn set_rsa_mgf1_md(&mut self, md: &MdRef) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set_rsa_mgf1_md(
                self.as_ptr(),
                md.as_ptr(),
            ))?;
        }

        Ok(())
    }

    /// Sets the RSA OAEP algorithm.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set_rsa_oaep_md`].
    ///
    /// [`EVP_PKEY_CTX_set_rsa_oaep_md`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_set_rsa_oaep_md.html
    #[cfg(any(ossl102, libressl310))]
    #[inline]
    pub fn set_rsa_oaep_md(&mut self, md: &MdRef) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set_rsa_oaep_md(
                self.as_ptr(),
                md.as_ptr() as *mut _,
            ))?;
        }

        Ok(())
    }

    /// Sets the RSA OAEP label.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set0_rsa_oaep_label`].
    ///
    /// [`EVP_PKEY_CTX_set0_rsa_oaep_label`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_set0_rsa_oaep_label.html
    #[cfg(any(ossl102, libressl310))]
    pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
        let len = c_int::try_from(label.len()).unwrap();

        unsafe {
            let p = util::crypto_malloc(label.len())?;
            ptr::copy_nonoverlapping(label.as_ptr(), p as *mut _, label.len());

            let r = cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(self.as_ptr(), p, len));
            if r.is_err() {
                util::crypto_free(p);
            }
            r?;
        }

        Ok(())
    }

    /// Prepares the context for encryption using the public key.
    ///
    /// This corresponds to [`EVP_PKEY_encrypt_init`].
    ///
    /// [`EVP_PKEY_encrypt_init`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_encrypt_init.html
    #[inline]
    pub fn encrypt_init(&mut self) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_encrypt_init(self.as_ptr()))?;
        }

        Ok(())
    }

    /// Prepares the context for encryption using the private key.
    ///
    /// This corresponds to [`EVP_PKEY_decrypt_init`].
    ///
    /// [`EVP_PKEY_decrypt_init`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_decrypt_init.html
    #[inline]
    pub fn decrypt_init(&mut self) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_decrypt_init(self.as_ptr()))?;
        }

        Ok(())
    }

    /// Encrypts data using the public key.
    ///
    /// If `to` is set to `None`, the an upper bound on the number of bytes required for the output buffer will be
    /// returned.
    ///
    /// This corresponds to [`EVP_PKEY_encrypt`].
    ///
    /// [`EVP_PKEY_encrypt`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_encrypt.html
    #[inline]
    pub fn encrypt(&mut self, from: &[u8], to: Option<&mut [u8]>) -> Result<usize, ErrorStack> {
        let mut written = to.as_ref().map_or(0, |b| b.len());
        unsafe {
            cvt(ffi::EVP_PKEY_encrypt(
                self.as_ptr(),
                to.map_or(ptr::null_mut(), |b| b.as_mut_ptr()),
                &mut written,
                from.as_ptr(),
                from.len(),
            ))?;
        }

        Ok(written)
    }

    /// Like [`Self::encrypt`] but appends ciphertext to a [`Vec`].
    pub fn encrypt_to_vec(&mut self, from: &[u8], out: &mut Vec<u8>) -> Result<usize, ErrorStack> {
        let base = out.len();
        let len = self.encrypt(from, None)?;
        out.resize(base + len, 0);
        let len = self.encrypt(from, Some(&mut out[base..]))?;
        out.truncate(base + len);
        Ok(len)
    }

    /// Decrypts data using the private key.
    ///
    /// If `to` is set to `None`, the an upper bound on the number of bytes required for the output buffer will be
    /// returned.
    ///
    /// This corresponds to [`EVP_PKEY_decrypt`].
    ///
    /// [`EVP_PKEY_decrypt`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_encrypt.html
    #[inline]
    pub fn decrypt(&mut self, from: &[u8], to: Option<&mut [u8]>) -> Result<usize, ErrorStack> {
        let mut written = to.as_ref().map_or(0, |b| b.len());
        unsafe {
            cvt(ffi::EVP_PKEY_decrypt(
                self.as_ptr(),
                to.map_or(ptr::null_mut(), |b| b.as_mut_ptr()),
                &mut written,
                from.as_ptr(),
                from.len(),
            ))?;
        }

        Ok(written)
    }

    /// Like [`Self::decrypt`] but appends plaintext to a [`Vec`].
    pub fn decrypt_to_vec(&mut self, from: &[u8], out: &mut Vec<u8>) -> Result<usize, ErrorStack> {
        let base = out.len();
        let len = self.decrypt(from, None)?;
        out.resize(base + len, 0);
        let len = self.decrypt(from, Some(&mut out[base..]))?;
        out.truncate(base + len);
        Ok(len)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::md::Md;
    use crate::pkey::PKey;
    use crate::rsa::Rsa;

    #[test]
    fn decrypt_without_private_key() {
        let key = Rsa::public_key_from_pem(include_bytes!("../test/rsa.pem.pub")).unwrap();
        let key = PKey::from_rsa(key).unwrap();

        let mut ctx = PkeyCtx::new(&key).unwrap();

        let pt = "hello".as_bytes();

        ctx.encrypt_init().unwrap();
        let mut ct = vec![];
        ctx.encrypt_to_vec(pt, &mut ct).unwrap();

        ctx.decrypt_init().unwrap();
        let mut out = vec![];
        ctx.decrypt_to_vec(&ct, &mut out).unwrap_err();
    }

    #[test]
    fn rsa() {
        let key = include_bytes!("../test/rsa.pem");
        let rsa = Rsa::private_key_from_pem(key).unwrap();
        let pkey = PKey::from_rsa(rsa).unwrap();

        let mut ctx = PkeyCtx::new(&pkey).unwrap();
        ctx.encrypt_init().unwrap();
        ctx.set_rsa_padding(Padding::PKCS1).unwrap();

        let pt = "hello world".as_bytes();
        let mut ct = vec![];
        ctx.encrypt_to_vec(pt, &mut ct).unwrap();

        ctx.decrypt_init().unwrap();
        ctx.set_rsa_padding(Padding::PKCS1).unwrap();

        let mut out = vec![];
        ctx.decrypt_to_vec(&ct, &mut out).unwrap();

        assert_eq!(pt, out);
    }

    #[test]
    #[cfg(any(ossl102, libressl310))]
    fn rsa_oaep() {
        let key = include_bytes!("../test/rsa.pem");
        let rsa = Rsa::private_key_from_pem(key).unwrap();
        let pkey = PKey::from_rsa(rsa).unwrap();

        let mut ctx = PkeyCtx::new(&pkey).unwrap();
        ctx.encrypt_init().unwrap();
        ctx.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        ctx.set_rsa_oaep_md(Md::sha256()).unwrap();
        ctx.set_rsa_mgf1_md(Md::sha256()).unwrap();

        let pt = "hello world".as_bytes();
        let mut ct = vec![];
        ctx.encrypt_to_vec(pt, &mut ct).unwrap();

        ctx.decrypt_init().unwrap();
        ctx.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        ctx.set_rsa_oaep_md(Md::sha256()).unwrap();
        ctx.set_rsa_mgf1_md(Md::sha256()).unwrap();

        let mut out = vec![];
        ctx.decrypt_to_vec(&ct, &mut out).unwrap();

        assert_eq!(pt, out);
    }
}
+31 −2
Original line number Diff line number Diff line
use crate::cvt_p;
use crate::error::ErrorStack;
use cfg_if::cfg_if;
use foreign_types::{ForeignType, ForeignTypeRef};
use libc::{c_char, c_int, c_void};
use std::any::Any;
use std::convert::TryFrom;
use std::panic::{self, AssertUnwindSafe};
use std::slice;

use crate::error::ErrorStack;

/// Wraps a user-supplied callback and a slot for panics thrown inside the callback (while FFI
/// frames are on the stack).
///
@@ -92,3 +94,30 @@ pub trait ForeignTypeRefExt: ForeignTypeRef {
    }
}
impl<FT: ForeignTypeRef> ForeignTypeRefExt for FT {}

#[track_caller]
#[inline]
pub fn crypto_malloc(len: usize) -> Result<*mut c_void, ErrorStack> {
    // 1.0.2 uses c_int but 1.1.0+ uses size_t
    let len = TryFrom::try_from(len).unwrap();

    unsafe {
        cvt_p(ffi::CRYPTO_malloc(
            len,
            concat!(file!(), "\0").as_ptr() as *const _,
            line!() as _,
        ))
    }
}

#[track_caller]
#[inline]
pub unsafe fn crypto_free(ptr: *mut c_void) {
    cfg_if! {
        if #[cfg(ossl110)] {
            ffi::CRYPTO_free(ptr, concat!(file!(), "\0").as_ptr() as *const _, line!() as _);
        } else {
            ffi::CRYPTO_free(ptr);
        }
    }
}