Commit b0c0c72f authored by Tom Leavy's avatar Tom Leavy
Browse files

Added PKeyContext supporting ECDH + HKDF

parent c65d85af
Loading
Loading
Loading
Loading
+124 −0
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ pub const EVP_PKEY_X448: c_int = NID_X448;
pub const EVP_PKEY_ED448: c_int = NID_ED448;
pub const EVP_PKEY_HMAC: c_int = NID_hmac;
pub const EVP_PKEY_CMAC: c_int = NID_cmac;
#[cfg(ossl110)]
pub const EVP_PKEY_HKDF: c_int = NID_hkdf;

pub const EVP_CTRL_GCM_SET_IVLEN: c_int = 0x9;
pub const EVP_CTRL_GCM_GET_TAG: c_int = 0x10;
@@ -467,6 +469,7 @@ cfg_if! {
        pub const EVP_PKEY_OP_VERIFYCTX: c_int = 1 << 8;
        pub const EVP_PKEY_OP_ENCRYPT: c_int = 1 << 9;
        pub const EVP_PKEY_OP_DECRYPT: c_int = 1 << 10;
        pub const EVP_PKEY_OP_DERIVE: c_int = 1 << 11;
    } else {
        pub const EVP_PKEY_OP_SIGN: c_int = 1 << 3;
        pub const EVP_PKEY_OP_VERIFY: c_int = 1 << 4;
@@ -475,6 +478,7 @@ cfg_if! {
        pub const EVP_PKEY_OP_VERIFYCTX: c_int = 1 << 7;
        pub const EVP_PKEY_OP_ENCRYPT: c_int = 1 << 8;
        pub const EVP_PKEY_OP_DECRYPT: c_int = 1 << 9;
        pub const EVP_PKEY_OP_DERIVE: c_int = 1 << 10;
    }
}

@@ -492,6 +496,30 @@ pub const EVP_PKEY_CTRL_CIPHER: c_int = 12;

pub const EVP_PKEY_ALG_CTRL: c_int = 0x1000;

#[cfg(ossl111)]
pub const EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND: c_int = 0;

#[cfg(ossl111)]
pub const EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY: c_int = 1;

#[cfg(ossl111)]
pub const EVP_PKEY_HKDEF_MODE_EXPAND_ONLY: c_int = 2;

#[cfg(ossl110)]
pub const EVP_PKEY_CTRL_HKDF_MD: c_int = EVP_PKEY_ALG_CTRL + 3;

#[cfg(ossl110)]
pub const EVP_PKEY_CTRL_HKDF_SALT: c_int = EVP_PKEY_ALG_CTRL + 4;

#[cfg(ossl110)]
pub const EVP_PKEY_CTRL_HKDF_KEY: c_int = EVP_PKEY_ALG_CTRL + 5;

#[cfg(ossl110)]
pub const EVP_PKEY_CTRL_HKDF_INFO: c_int = EVP_PKEY_ALG_CTRL + 6;

#[cfg(ossl111)]
pub const EVP_PKEY_CTRL_HKDF_MODE: c_int = EVP_PKEY_ALG_CTRL + 7;

extern "C" {
    pub fn EVP_PKEY_CTX_new(k: *mut EVP_PKEY, e: *mut ENGINE) -> *mut EVP_PKEY_CTX;
    pub fn EVP_PKEY_CTX_new_id(id: c_int, e: *mut ENGINE) -> *mut EVP_PKEY_CTX;
@@ -538,6 +566,102 @@ extern "C" {
    ) -> c_int;
}

// HKDF Functions
cfg_if! {
    if #[cfg(ossl300)] {
        extern "C" {
            pub fn EVP_PKEY_CTX_set_hkdf_mode(ctx: *mut EVP_PKEY_CTX, mode: c_int) -> c_int;
            pub fn EVP_PKEY_CTX_set_hkdf_md(ctx: *mut EVP_PKEY_CTX, md: *const EVP_MD) -> c_int;
            pub fn EVP_PKEY_CTX_set1_hkdf_salt(
                ctx: *mut EVP_PKEY_CTX,
                salt: *const u8,
                saltlen: c_int,
            ) -> c_int;
            pub fn EVP_PKEY_CTX_set1_hkdf_key(
                ctx: *mut EVP_PKEY_CTX,
                key: *const u8,
                keylen: c_int,
            ) -> c_int;
            pub fn EVP_PKEY_CTX_add1_hkdf_info(
                ctx: *mut EVP_PKEY_CTX,
                info: *const u8,
                infolen: c_int,
            ) -> c_int;
        }
    } else {
        #[cfg(ossl111)]
        pub unsafe fn EVP_PKEY_CTX_set_hkdf_mode(ctx: *mut EVP_PKEY_CTX, mode: c_int) -> c_int {
            EVP_PKEY_CTX_ctrl(
                ctx,
                -1,
                EVP_PKEY_OP_DERIVE,
                EVP_PKEY_CTRL_HKDF_MODE,
                mode, std::ptr::null_mut(),
            )
        }

        #[cfg(ossl110)]
        pub unsafe fn EVP_PKEY_CTX_set_hkdf_md(ctx: *mut EVP_PKEY_CTX, md: *const EVP_MD) -> c_int {
            EVP_PKEY_CTX_ctrl(
                ctx,
                -1,
                EVP_PKEY_OP_DERIVE,
                EVP_PKEY_CTRL_HKDF_MD,
                0,
                md as *mut c_void,
            )
        }

        #[cfg(ossl110)]
        pub unsafe fn EVP_PKEY_CTX_set1_hkdf_salt(
            ctx: *mut EVP_PKEY_CTX,
            salt: *const u8,
            saltlen: c_int,
        ) -> c_int {
            EVP_PKEY_CTX_ctrl(
                ctx,
                -1,
                EVP_PKEY_OP_DERIVE,
                EVP_PKEY_CTRL_HKDF_SALT,
                saltlen,
                salt as *mut c_void,
            )
        }

        #[cfg(ossl110)]
        pub unsafe fn EVP_PKEY_CTX_set1_hkdf_key(
            ctx: *mut EVP_PKEY_CTX,
            key: *const u8,
            keylen: c_int,
        ) -> c_int {
            EVP_PKEY_CTX_ctrl(
                ctx,
                -1,
                EVP_PKEY_OP_DERIVE,
                EVP_PKEY_CTRL_HKDF_KEY,
                keylen,
                key as *mut c_void,
            )
        }

        #[cfg(ossl110)]
        pub unsafe fn EVP_PKEY_CTX_add1_hkdf_info(
            ctx: *mut EVP_PKEY_CTX,
            info: *const u8,
            infolen: c_int,
        ) -> c_int {
            EVP_PKEY_CTX_ctrl(
                ctx,
                -1,
                EVP_PKEY_OP_DERIVE,
                EVP_PKEY_CTRL_HKDF_INFO,
                infolen,
                info as *mut c_void,
            )
        }
    }
}

const_ptr_api! {
    extern "C" {
        pub fn EVP_PKCS82PKEY(p8: #[const_ptr_if(any(ossl110, libressl280))] PKCS8_PRIV_KEY_INFO) -> *mut EVP_PKEY;
+2 −0
Original line number Diff line number Diff line
@@ -916,6 +916,8 @@ pub const NID_aes_256_cbc_hmac_sha1: c_int = 918;
pub const NID_X25519: c_int = 1034;
#[cfg(ossl111)]
pub const NID_X448: c_int = 1035;
#[cfg(ossl110)]
pub const NID_hkdf: c_int = 1036;
#[cfg(ossl111)]
pub const NID_ED25519: c_int = 1087;
#[cfg(ossl111)]
+1 −0
Original line number Diff line number Diff line
@@ -157,6 +157,7 @@ pub mod pkcs12;
pub mod pkcs5;
pub mod pkcs7;
pub mod pkey;
pub mod pkey_ctx;
pub mod rand;
pub mod rsa;
pub mod sha;
+3 −0
Original line number Diff line number Diff line
@@ -80,6 +80,9 @@ impl Id {
    pub const DH: Id = Id(ffi::EVP_PKEY_DH);
    pub const EC: Id = Id(ffi::EVP_PKEY_EC);

    #[cfg(ossl110)]
    pub const HKDF: Id = Id(ffi::EVP_PKEY_HKDF);

    #[cfg(ossl111)]
    pub const ED25519: Id = Id(ffi::EVP_PKEY_ED25519);
    #[cfg(ossl111)]
+394 −0
Original line number Diff line number Diff line
//! PKeyContext operations
use crate::error::ErrorStack;
#[cfg(ossl110)]
use crate::hash::MessageDigest;
use crate::pkey::{HasPrivate, HasPublic, Id, PKeyRef};
use crate::{cvt, cvt_p};
use foreign_types::ForeignTypeRef;
use std::marker::PhantomData;
#[cfg(ossl110)]
use std::os::raw::{c_int, c_uchar};
use std::ptr;

/// HKDF modes of operation. See [`hkdf_mode`](PKeyContext::set_hkdf_mode)
#[cfg(ossl111)]
pub struct HkdfMode(c_int);

impl HkdfMode {
    /// Extract followed by expand
    pub const EXTRACT_THEN_EXPAND: HkdfMode = HkdfMode(ffi::EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND);
    /// Extract only HKDF
    pub const EXTRACT: HkdfMode = HkdfMode(ffi::EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY);
    /// Expand only HKDF
    pub const EXPAND: HkdfMode = HkdfMode(ffi::EVP_PKEY_HKDEF_MODE_EXPAND_ONLY);
}

/// A type used for performing operations with PKeys
pub struct PKeyContext<'a>(*mut ffi::EVP_PKEY_CTX, PhantomData<&'a ()>);

unsafe impl<'a> Sync for PKeyContext<'a> {}
unsafe impl<'a> Send for PKeyContext<'a> {}

#[allow(clippy::len_without_is_empty)]
impl<'a> PKeyContext<'a> {
    /// Creates a new `PKeyContext` using the provided private key.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_new`].
    ///
    /// [`EVP_PKEY_CTX_new`]: https://www.openssl.org/docs/man1.0.2/crypto/EVP_PKEY_CTX_new.html
    pub fn new<T>(key: &'a PKeyRef<T>) -> Result<PKeyContext<'a>, ErrorStack>
    where
        T: HasPrivate,
    {
        unsafe {
            cvt_p(ffi::EVP_PKEY_CTX_new(key.as_ptr(), ptr::null_mut()))
                .map(|p| PKeyContext(p, PhantomData))
        }
    }

    /// Initialize the PkeyContext for key derivation operations
    ///
    /// This corresponds to [`EVP_PKEY_derive_init`].
    ///
    /// [`EVP_PKEY_derive_init`]: https://www.openssl.org/docs/man1.0.2/crypto/EVP_PKEY_derive_init.html
    pub fn derive_init(&self) -> Result<(), ErrorStack> {
        unsafe { cvt(ffi::EVP_PKEY_derive_init(self.0)).map(|_| ()) }
    }

    /// Sets the peer key used for secret derivation.
    ///
    /// This corresponds to [`EVP_PKEY_derive_set_peer`]:
    ///
    /// [`EVP_PKEY_derive_set_peer`]: https://www.openssl.org/docs/man1.0.2/crypto/EVP_PKEY_derive_init.html
    pub fn set_peer<T>(&mut self, key: &'a PKeyRef<T>) -> Result<(), ErrorStack>
    where
        T: HasPublic,
    {
        unsafe { cvt(ffi::EVP_PKEY_derive_set_peer(self.0, key.as_ptr())).map(|_| ()) }
    }

    /// Creates a new `PKeyContext` using the algorithm specified by `id`.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_new_id`]
    ///
    /// [`EVP_PKEY_CTX_new_id`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_new_id.html
    #[cfg(ossl110)]
    pub fn new_id(id: Id) -> Result<Self, ErrorStack> {
        unsafe {
            cvt_p(ffi::EVP_PKEY_CTX_new_id(id.as_raw(), ptr::null_mut()))
                .map(|p| PKeyContext(p, PhantomData))
        }
    }

    /// Sets the digest to use for HKDF derivation.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set_hkdf_md`].
    ///
    /// # Warning
    /// This function will result in an error unless the `PKeyContext` was created with
    /// [`new_id`](PKeyContext::new_id) specifying the `HKDF` Id.
    ///
    /// [`EVP_PKEY_CTX_set_hkdf_md`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_set_hkdf_md.html
    #[cfg(ossl110)]
    pub fn set_hkdf_md(&mut self, digest: MessageDigest) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set_hkdf_md(self.0, digest.as_ptr()))?;
        }

        Ok(())
    }

    /// Sets the HKDF mode of operation.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_hkdf_mode`]
    ///
    /// # Warning
    /// This function will result in an error unless the context was created with
    /// [new_id](PKeyContext::new_id) specifying the `HKDF` Id.
    ///
    /// [`EVP_PKEY_CTX_hkdf_mode`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_hkdf_mode.html
    #[cfg(ossl111)]
    pub fn set_hkdf_mode(&mut self, mode: HkdfMode) -> Result<(), ErrorStack> {
        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set_hkdf_mode(self.0, mode.0))?;
        }

        Ok(())
    }

    /// Sets the input keying material for HKDF derivation.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set1_hkdf_key`].
    ///
    /// [`EVP_PKEY_CTX_set1_hkdf_key`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_set1_hkdf_key.html
    #[cfg(ossl110)]
    pub fn set_hkdf_key(&mut self, key: &[u8]) -> Result<(), ErrorStack> {
        let len = key.len();
        assert!(len <= std::i32::MAX as usize);

        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set1_hkdf_key(
                self.0,
                key.as_ptr() as *mut c_uchar,
                len as c_int,
            ))?;

            Ok(())
        }
    }

    /// Sets the salt value for HKDF derivation.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set1_hkdf_salt`].
    ///
    /// [`EVP_PKEY_CTX_set1_hkdf_salt`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_set1_hkdf_salt.html
    #[cfg(ossl110)]
    pub fn set_hkdf_salt(&mut self, salt: &[u8]) -> Result<(), ErrorStack> {
        let len = salt.len();
        assert!(len <= std::i32::MAX as usize);

        unsafe {
            cvt(ffi::EVP_PKEY_CTX_set1_hkdf_salt(
                self.0,
                salt.as_ptr() as *mut c_uchar,
                len as c_int,
            ))?;

            Ok(())
        }
    }

    /// Appends info bytes for HKDF derivation.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_add1_hkdf_info`].
    ///
    /// # Warning
    ///
    /// On OpenSSL versions < 3.0, total length of the `info` buffer must not exceed 1024 bytes
    /// in length
    ///
    /// [`EVP_PKEY_CTX_add1_hkdf_info`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_CTX_add1_hkdf_info.html
    #[cfg(ossl110)]
    pub fn add_hkdf_info(&mut self, info: &[u8]) -> Result<(), ErrorStack> {
        let len = info.len();
        assert!(len <= std::i32::MAX as usize);

        unsafe {
            cvt(ffi::EVP_PKEY_CTX_add1_hkdf_info(
                self.0,
                info.as_ptr() as *mut c_uchar,
                len as c_int,
            ))?;

            Ok(())
        }
    }

    /// Returns the size of the derivation output.
    ///
    /// It can be used to size the buffer passed to [`PKeyContext::derive`].
    ///
    /// This corresponds to [`EVP_PKEY_derive`].
    ///
    /// # Warning
    ///
    /// When using this `PKeyContext` for HKDF, this function is only allowed when using HKDF with
    /// [`EXTRACT`](HkdfMode::EXTRACT) mode.
    ///
    /// [`PKeyContext::derive`]: #method.derive
    /// [`EVP_PKEY_derive`]: https://www.openssl.org/docs/man1.0.2/crypto/EVP_PKEY_derive_init.html
    pub fn derive_len(&mut self) -> Result<usize, ErrorStack> {
        unsafe {
            let mut len = 0;
            cvt(ffi::EVP_PKEY_derive(self.0, ptr::null_mut(), &mut len)).map(|_| len)
        }
    }

    /// Derive the configured output based on algorithm
    ///
    /// This corresponds to [`EVP_PKEY_derive`]
    ///
    /// [`EVP_PKEY_derive`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_PKEY_derive.html
    pub fn derive(&mut self, buf: &mut [u8]) -> Result<usize, ErrorStack> {
        unsafe {
            let mut len = buf.len();
            cvt(ffi::EVP_PKEY_derive(self.0, buf.as_mut_ptr(), &mut len))?;
            Ok(len)
        }
    }

    /// A convenience function which derives a shared secret and returns it in a new buffer.
    ///
    /// This simply wraps [`PKeyContext::len`] and [`PKeyContext::derive`].
    ///
    /// # Warning
    ///
    /// When using this `PKeyContext` for HKDF, this function is only allowed when using HKDF with
    /// [`EXTRACT`](HkdfMode::EXTRACT) mode.
    /// [`PKeyContext::len`]: #method.len
    /// [`PKeyContext::derive`]: #method.derive
    pub fn derive_to_vec(&mut self) -> Result<Vec<u8>, ErrorStack> {
        let len = self.derive_len()?;
        let mut buf = vec![0; len];
        let len = self.derive(&mut buf)?;
        buf.truncate(len);
        Ok(buf)
    }
}

impl<'a> Drop for PKeyContext<'a> {
    fn drop(&mut self) {
        unsafe {
            ffi::EVP_PKEY_CTX_free(self.0);
        }
    }
}

/// One-shot HKDF expand, filling the buffer
#[cfg(ossl111)]
pub fn hkdf_expand(
    digest: MessageDigest,
    key: &[u8],
    info: &[u8],
    buf: &mut [u8],
) -> Result<(), ErrorStack> {
    let mut ctx = PKeyContext::new_id(Id::HKDF)?;
    ctx.derive_init()?;
    ctx.set_hkdf_md(digest)?;
    ctx.set_hkdf_mode(HkdfMode::EXPAND)?;
    ctx.set_hkdf_key(key)?;
    ctx.add_hkdf_info(info)?;
    ctx.derive(buf)?;

    Ok(())
}

/// One-shot HKDF extract
#[cfg(ossl111)]
pub fn hkdf_extract(digest: MessageDigest, key: &[u8], salt: &[u8]) -> Result<Vec<u8>, ErrorStack> {
    let mut ctx = PKeyContext::new_id(Id::HKDF)?;
    ctx.derive_init()?;
    ctx.set_hkdf_md(digest)?;
    ctx.set_hkdf_mode(HkdfMode::EXTRACT)?;
    ctx.set_hkdf_key(key)?;
    ctx.set_hkdf_salt(salt)?;

    let mut buf = vec![0u8; ctx.derive_len()?];
    ctx.derive(&mut buf)?;
    Ok(buf)
}

/// One-shot HKDF extract-and-expand, filling the buffer
#[cfg(ossl110)]
pub fn hkdf(
    digest: MessageDigest,
    key: &[u8],
    salt: &[u8],
    info: &[u8],
    buf: &mut [u8],
) -> Result<(), ErrorStack> {
    let mut ctx = PKeyContext::new_id(Id::HKDF)?;
    ctx.derive_init()?;
    ctx.set_hkdf_md(digest)?;
    ctx.set_hkdf_key(key)?;
    ctx.set_hkdf_salt(salt)?;
    ctx.add_hkdf_info(info)?;
    ctx.derive(buf)?;

    Ok(())
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::ec::{EcGroup, EcKey};
    use crate::nid::Nid;
    use crate::pkey::PKey;
    #[cfg(ossl110)]
    use hex::{self, FromHex};

    // ECDH DERIVE TESTS
    #[test]
    fn derive_without_peer() {
        let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let ec_key = EcKey::generate(&group).unwrap();
        let pkey = PKey::from_ec_key(ec_key).unwrap();
        let mut ctx = PKeyContext::new(&pkey).unwrap();
        ctx.derive_init().unwrap();
        ctx.derive_to_vec().unwrap_err();
    }

    #[test]
    fn test_ec_key_derive() {
        let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let ec_key = EcKey::generate(&group).unwrap();
        let ec_key2 = EcKey::generate(&group).unwrap();
        let pkey = PKey::from_ec_key(ec_key).unwrap();
        let pkey2 = PKey::from_ec_key(ec_key2).unwrap();
        let mut ctx = PKeyContext::new(&pkey).unwrap();
        ctx.derive_init().unwrap();
        ctx.set_peer(&pkey2).unwrap();
        let shared = ctx.derive_to_vec().unwrap();
        assert!(!shared.is_empty());
    }

    // HKDF DERIVE TESTS
    #[cfg(ossl110)]
    const IKM: &str = "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b";
    #[cfg(ossl110)]
    const SALT: &str = "000102030405060708090a0b0c";
    #[cfg(ossl110)]
    const INFO: &str = "f0f1f2f3f4f5f6f7f8f9";
    #[cfg(ossl110)]
    const L: usize = 42;

    #[cfg(ossl111)]
    const PRK: &str = "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5";

    #[cfg(ossl110)]
    const OKM: &str = "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf\
                       34007208d5b887185865";

    #[cfg(ossl111)]
    #[test]
    fn test_hkdf_expand() {
        let ikm = Vec::from_hex(PRK).unwrap();
        let info = Vec::from_hex(INFO).unwrap();
        let mut out = vec![0u8; L];

        hkdf_expand(MessageDigest::sha256(), &ikm, &info, &mut out).unwrap();
        assert_eq!(out, Vec::from_hex(OKM).unwrap());
    }

    #[cfg(ossl111)]
    #[test]
    fn test_hkdf_extract() {
        let ikm = Vec::from_hex(IKM).unwrap();
        let salt = Vec::from_hex(SALT).unwrap();
        let out = hkdf_extract(MessageDigest::sha256(), &ikm, &salt).unwrap();
        assert_eq!(out, Vec::from_hex(PRK).unwrap());
    }

    #[cfg(ossl110)]
    #[test]
    fn test_hkdf() {
        let ikm = Vec::from_hex(IKM).unwrap();
        let salt = Vec::from_hex(SALT).unwrap();
        let info = Vec::from_hex(INFO).unwrap();
        let mut out = vec![0u8; L];

        hkdf(MessageDigest::sha256(), &ikm, &salt, &info, &mut out).unwrap();
        assert_eq!(out, Vec::from_hex(OKM).unwrap());
    }

    #[cfg(ossl110)]
    #[cfg(not(ossl300))]
    #[test]
    fn test_large_hkdf_info() {
        let too_big = vec![0u8; 1025];
        let mut hkdf = PKeyContext::new_id(Id::HKDF).unwrap();
        hkdf.derive_init().unwrap();
        hkdf.set_hkdf_md(MessageDigest::sha256()).unwrap();
        assert!(hkdf.add_hkdf_info(&too_big).is_err());
    }
}
Loading