Commit 63f32c1a authored by Ivan Paljak's avatar Ivan Paljak
Browse files

Enable setting RSA OAEP label in Decrypter

parent d8aa8a9a
Loading
Loading
Loading
Loading
+59 −0
Original line number Diff line number Diff line
@@ -363,6 +363,32 @@ impl<'a> Decrypter<'a> {
        }
    }

    /// Sets the RSA OAEP label.
    ///
    /// This is only useful for RSA keys.
    ///
    /// This corresponds to [`EVP_PKEY_CTX_set0_rsa_oaep_label`].
    ///
    /// [`EVP_PKEY_CTX_set0_rsa_oaep_label`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_set0_rsa_oaep_label.html
    #[cfg(any(ossl102, libressl310))]
    pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> {
        unsafe {
            let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?;
            ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len());

            cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label(
                self.pctx,
                p as *mut c_void,
                label.len() as c_int,
            ))
            .map(|_| ())
            .map_err(|e| {
                ffi::OPENSSL_free(p);
                e
            })
        }
    }

    /// Performs public key decryption.
    ///
    /// In order to know the size needed for the output buffer, use [`decrypt_len`](Decrypter::decrypt_len).
@@ -516,4 +542,37 @@ mod test {

        assert_eq!(decoded, &*input);
    }

    #[test]
    #[cfg(any(ossl102, libressl310))]
    fn rsa_encrypt_decrypt_oaep_label() {
        let key = include_bytes!("../test/rsa.pem");
        let private_key = Rsa::private_key_from_pem(key).unwrap();
        let pkey = PKey::from_rsa(private_key).unwrap();

        let mut encrypter = Encrypter::new(&pkey).unwrap();
        encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let input = Vec::from_hex(INPUT).unwrap();
        let buffer_len = encrypter.encrypt_len(&input).unwrap();
        let mut encoded = vec![0u8; buffer_len];
        let encoded_len = encrypter.encrypt(&input, &mut encoded).unwrap();
        let encoded = &encoded[..encoded_len];

        let mut decrypter = Decrypter::new(&pkey).unwrap();
        decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap();
        decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap();
        let buffer_len = decrypter.decrypt_len(encoded).unwrap();
        let mut decoded = vec![0u8; buffer_len];
        let decoded_len = decrypter.decrypt(encoded, &mut decoded).unwrap();
        let decoded = &decoded[..decoded_len];

        assert_eq!(decoded, &*input);

        decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap();
        let buffer_len = decrypter.decrypt_len(encoded).unwrap();
        let mut decoded = vec![0u8; buffer_len];

        assert!(decrypter.decrypt(encoded, &mut decoded).is_err());
    }
}