diff --git a/openssl/src/encrypt.rs b/openssl/src/encrypt.rs index e2acd3f78598a3330adde1db47351bf5fee07dee..3cb10fcca2fb2c3c9480c0bb3aa3a6f4ccea2f6d 100644 --- a/openssl/src/encrypt.rs +++ b/openssl/src/encrypt.rs @@ -363,6 +363,32 @@ impl<'a> Decrypter<'a> { } } + /// Sets the RSA OAEP label. + /// + /// This is only useful for RSA keys. + /// + /// This corresponds to [`EVP_PKEY_CTX_set0_rsa_oaep_label`]. + /// + /// [`EVP_PKEY_CTX_set0_rsa_oaep_label`]: https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_CTX_set0_rsa_oaep_label.html + #[cfg(any(ossl102, libressl310))] + pub fn set_rsa_oaep_label(&mut self, label: &[u8]) -> Result<(), ErrorStack> { + unsafe { + let p = cvt_p(ffi::OPENSSL_malloc(label.len() as _))?; + ptr::copy_nonoverlapping(label.as_ptr(), p as *mut u8, label.len()); + + cvt(ffi::EVP_PKEY_CTX_set0_rsa_oaep_label( + self.pctx, + p as *mut c_void, + label.len() as c_int, + )) + .map(|_| ()) + .map_err(|e| { + ffi::OPENSSL_free(p); + e + }) + } + } + /// Performs public key decryption. /// /// In order to know the size needed for the output buffer, use [`decrypt_len`](Decrypter::decrypt_len). @@ -516,4 +542,37 @@ mod test { assert_eq!(decoded, &*input); } + + #[test] + #[cfg(any(ossl102, libressl310))] + fn rsa_encrypt_decrypt_oaep_label() { + let key = include_bytes!("../test/rsa.pem"); + let private_key = Rsa::private_key_from_pem(key).unwrap(); + let pkey = PKey::from_rsa(private_key).unwrap(); + + let mut encrypter = Encrypter::new(&pkey).unwrap(); + encrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap(); + encrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap(); + let input = Vec::from_hex(INPUT).unwrap(); + let buffer_len = encrypter.encrypt_len(&input).unwrap(); + let mut encoded = vec![0u8; buffer_len]; + let encoded_len = encrypter.encrypt(&input, &mut encoded).unwrap(); + let encoded = &encoded[..encoded_len]; + + let mut decrypter = Decrypter::new(&pkey).unwrap(); + decrypter.set_rsa_padding(Padding::PKCS1_OAEP).unwrap(); + decrypter.set_rsa_oaep_label(b"test_oaep_label").unwrap(); + let buffer_len = decrypter.decrypt_len(encoded).unwrap(); + let mut decoded = vec![0u8; buffer_len]; + let decoded_len = decrypter.decrypt(encoded, &mut decoded).unwrap(); + let decoded = &decoded[..decoded_len]; + + assert_eq!(decoded, &*input); + + decrypter.set_rsa_oaep_label(b"wrong_oaep_label").unwrap(); + let buffer_len = decrypter.decrypt_len(encoded).unwrap(); + let mut decoded = vec![0u8; buffer_len]; + + assert!(decrypter.decrypt(encoded, &mut decoded).is_err()); + } }