diff --git a/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs b/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs new file mode 100644 index 0000000000..19508d4e2b --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs @@ -0,0 +1,304 @@ +//! Module containing primitives pertaining to [`GLWE ciphertext +//! keyswitch`](`GlweKeyswitchKey#glwe-keyswitch`). + +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::commons::math::decomposition::{ + SignedDecomposer, SignedDecomposerNonNative, +}; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// Keyswitch a [`GLWE ciphertext`](`GlweCiphertext`) encrypted under a +/// [`GLWE secret key`](`GlweSecretKey`) to another [`GLWE secret key`](`GlweSecretKey`). +/// +/// # Formal Definition +/// +/// See [`GLWE keyswitch key`](`GlweKeyswitchKey#glwe-keyswitch`). +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweKeyswitchKey creation +/// let input_glwe_dimension = GlweDimension(2); +/// let poly_size = PolynomialSize(512); +/// let glwe_noise_distribution = Gaussian::from_dispersion_parameter( +/// StandardDev(0.00000000000000000000007069849454709433), +/// 0.0, +/// ); +/// let output_glwe_dimension = GlweDimension(1); +/// let decomp_base_log = DecompositionBaseLog(21); +/// let decomp_level_count = DecompositionLevelCount(2); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta = 1 << 59; +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// input_glwe_dimension, +/// poly_size, +/// &mut secret_generator, +/// ); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// output_glwe_dimension, +/// poly_size, +/// &mut secret_generator, +/// ); +/// +/// let ksk = allocate_and_generate_new_glwe_keyswitch_key( +/// &input_glwe_secret_key, +/// &output_glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Create the plaintext +/// let msg = 3u64; +/// let plaintext_list = PlaintextList::new(msg * delta, PlaintextCount(poly_size.0)); +/// +/// // Create a new GlweCiphertext +/// let mut input_glwe = GlweCiphertext::new( +/// 0u64, +/// input_glwe_dimension.to_glwe_size(), +/// poly_size, +/// ciphertext_modulus, +/// ); +/// +/// encrypt_glwe_ciphertext( +/// &input_glwe_secret_key, +/// &mut input_glwe, +/// &plaintext_list, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_glwe = GlweCiphertext::new( +/// 0u64, +/// output_glwe_secret_key.glwe_dimension().to_glwe_size(), +/// output_glwe_secret_key.polynomial_size(), +/// ciphertext_modulus, +/// ); +/// +/// keyswitch_glwe_ciphertext(&ksk, &input_glwe, &mut output_glwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 5 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); +/// +/// let mut output_plaintext_list = PlaintextList::new(0u64, plaintext_list.plaintext_count()); +/// +/// decrypt_glwe_ciphertext( +/// &output_glwe_secret_key, +/// &output_glwe, +/// &mut output_plaintext_list, +/// ); +/// +/// // Get the raw vector +/// let mut cleartext_list = output_plaintext_list.into_container(); +/// // Remove the encoding +/// cleartext_list +/// .iter_mut() +/// .for_each(|elt| *elt = decomposer.decode_plaintext(*elt)); +/// // Get the list immutably +/// let cleartext_list = cleartext_list; +/// +/// // Check we recovered the original message for each plaintext we encrypted +/// cleartext_list.iter().for_each(|&elt| assert_eq!(elt, msg)); +/// ``` +pub fn keyswitch_glwe_ciphertext( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + if glwe_keyswitch_key + .ciphertext_modulus() + .is_compatible_with_native_modulus() + { + keyswitch_glwe_ciphertext_native_mod_compatible( + glwe_keyswitch_key, + input_glwe_ciphertext, + output_glwe_ciphertext, + ) + } else { + keyswitch_glwe_ciphertext_other_mod( + glwe_keyswitch_key, + input_glwe_ciphertext, + output_glwe_ciphertext, + ) + } +} + +pub fn keyswitch_glwe_ciphertext_native_mod_compatible( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() + == input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched input GlweDimension. \ + GlweKeyswitchKey input GlweDimension: {:?}, input GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() + == output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched output GlweDimension. \ + GlweKeyswitchKey output GlweDimension: {:?}, output GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.output_key_glwe_dimension(), + output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_ciphertext.polynomial_size(), + "Mismatched input PolynomialSize. \ + GlweKeyswithcKey input PolynomialSize: {:?}, input GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_ciphertext.polynomial_size(), + "Mismatched output PolynomialSize. \ + GlweKeyswitchKey output PolynomialSize: {:?}, output GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size(), + ); + assert!(glwe_keyswitch_key + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new( + glwe_keyswitch_key.decomposition_base_log(), + glwe_keyswitch_key.decomposition_level_count(), + ); + + for (keyswitch_key_block, input_mask_element) in glwe_keyswitch_key + .iter() + .zip(input_glwe_ciphertext.get_mask().as_polynomial_list().iter()) + { + let mut decomposition_iter = decomposer.decompose_slice(input_mask_element.as_ref()); + // loop over the number of levels + for level_key_ciphertext in keyswitch_key_block.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + polynomial_list_wrapping_sub_scalar_mul_assign( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &Polynomial::from_container(decomposed.as_slice()), + ); + } + } +} + +pub fn keyswitch_glwe_ciphertext_other_mod( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() + == input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched input GlweDimension. \ + GlweKeyswitchKey input GlweDimension: {:?}, input GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() + == output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched output GlweDimension. \ + GlweKeyswitchKey output GlweDimension: {:?}, output GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.output_key_glwe_dimension(), + output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_ciphertext.polynomial_size(), + "Mismatched input PolynomialSize. \ + GlweKeyswithcKey input PolynomialSize: {:?}, input GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_ciphertext.polynomial_size(), + "Mismatched output PolynomialSize. \ + GlweKeyswitchKey output PolynomialSize: {:?}, output GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size(), + ); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(!ciphertext_modulus.is_compatible_with_native_modulus()); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext (no need to use non native addition here) + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposerNonNative::new( + glwe_keyswitch_key.decomposition_base_log(), + glwe_keyswitch_key.decomposition_level_count(), + ciphertext_modulus, + ); + + let mut scalar_poly = Polynomial::new(Scalar::ZERO, input_glwe_ciphertext.polynomial_size()); + + for (keyswitch_key_block, input_mask_element) in glwe_keyswitch_key + .iter() + .zip(input_glwe_ciphertext.get_mask().as_polynomial_list().iter()) + { + let mut decomposition_iter = decomposer.decompose_slice(input_mask_element.as_ref()); + // loop over the number of levels + for level_key_ciphertext in keyswitch_key_block.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + decomposed.modular_value(scalar_poly.as_mut()); + polynomial_list_wrapping_sub_scalar_mul_assign_custom_mod( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &scalar_poly, + ciphertext_modulus.get_custom_modulus().cast_into(), + ); + } + } +} diff --git a/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs b/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs new file mode 100644 index 0000000000..96dbeb455b --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs @@ -0,0 +1,352 @@ +//! Module containing primitives pertaining to [`GLWE keyswitch key generation`](`GlweKeyswitchKey`) + +use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_div_assign; +use crate::core_crypto::algorithms::*; +use crate::core_crypto::commons::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::math::decomposition::{ + DecompositionLevel, DecompositionTermSlice, DecompositionTermSliceNonNative, +}; +use crate::core_crypto::commons::math::random::{Distribution, Uniform}; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// Fill a [`GLWE keyswitch key`](`GlweKeyswitchKey`) with an actual keyswitching key constructed +/// from an input and an output key [`GLWE secret key`](`GlweSecretKey`). +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweKeyswitchKey creation +/// let input_glwe_dimension = GlweDimension(2); +/// let polynomial_size = PolynomialSize(1024); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let output_glwe_dimension = GlweDimension(1); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the GlweSecretKey +/// let input_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// input_glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// output_glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let mut ksk = GlweKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// input_glwe_dimension, +/// output_glwe_dimension, +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_glwe_keyswitch_key( +/// &input_glwe_secret_key, +/// &output_glwe_secret_key, +/// &mut ksk, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// assert!(!ksk.as_ref().iter().all(|&x| x == 0)); +/// ``` +pub fn generate_glwe_keyswitch_key< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + + if ciphertext_modulus.is_compatible_with_native_modulus() { + generate_glwe_keyswitch_key_native_mod_compatible( + input_glwe_sk, + output_glwe_sk, + glwe_keyswitch_key, + noise_distribution, + generator, + ) + } else { + generate_glwe_keyswitch_key_other_mod( + input_glwe_sk, + output_glwe_sk, + glwe_keyswitch_key, + noise_distribution, + generator, + ) + } +} + +pub fn generate_glwe_keyswitch_key_native_mod_compatible< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() == input_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey input GlweDimension is not equal \ + to the input GlweSecretKey GlweDimension. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() == output_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey output GlweDimension is not equal \ + to the output GlweSecretKey GlweDimension. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.output_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey input PolynomialSize is not equal \ + to the input GlweSecretKey PolynomialSize. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.polynomial_size(), + input_glwe_sk.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey output PolynomialSize is not equal \ + to the output GlweSecretKey PolynomialSize. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.polynomial_size(), + output_glwe_sk.polynomial_size(), + ); + + let decomp_base_log = glwe_keyswitch_key.decomposition_base_log(); + let decomp_level_count = glwe_keyswitch_key.decomposition_level_count(); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(ciphertext_modulus.is_compatible_with_native_modulus()); + + // Iterate over the input key elements and the destination glwe_keyswitch_key memory + for (input_key_polynomial, mut keyswitch_key_block) in input_glwe_sk + .as_polynomial_list() + .iter() + .zip(glwe_keyswitch_key.iter_mut()) + { + // The plaintexts used to encrypt a key element will be stored in this buffer + let mut decomposition_polynomials_buffer = PolynomialList::new( + Scalar::ZERO, + input_glwe_sk.polynomial_size(), + PolynomialCount(decomp_level_count.0), + ); + + // We fill the buffer with the powers of the key elmements + for (level, mut message_polynomial) in (1..=decomp_level_count.0) + .rev() + .map(DecompositionLevel) + .zip(decomposition_polynomials_buffer.as_mut_view().iter_mut()) + { + let term = + DecompositionTermSlice::new(level, decomp_base_log, input_key_polynomial.as_ref()); + term.fill_slice_with_recomposition_summand(message_polynomial.as_mut()); + slice_wrapping_scalar_div_assign( + message_polynomial.as_mut(), + ciphertext_modulus.get_power_of_two_scaling_to_native_torus(), + ); + } + + let decomposition_plaintexts_buffer = + PlaintextList::from_container(decomposition_polynomials_buffer.into_container()); + + encrypt_glwe_ciphertext_list( + output_glwe_sk, + &mut keyswitch_key_block, + &decomposition_plaintexts_buffer, + noise_distribution, + generator, + ); + } +} + +pub fn generate_glwe_keyswitch_key_other_mod< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() == input_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey input GlweDimension is not equal \ + to the input GlweSecretKey GlweDimension. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() == output_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey output GlweDimension is not equal \ + to the output GlweSecretKey GlweDimension. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.output_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey input PolynomialSize is not equal \ + to the input GlweSecretKey PolynomialSize. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.polynomial_size(), + input_glwe_sk.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey output PolynomialSize is not equal \ + to the output GlweSecretKey PolynomialSize. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.polynomial_size(), + output_glwe_sk.polynomial_size(), + ); + + let decomp_base_log = glwe_keyswitch_key.decomposition_base_log(); + let decomp_level_count = glwe_keyswitch_key.decomposition_level_count(); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(!ciphertext_modulus.is_compatible_with_native_modulus()); + + // Iterate over the input key elements and the destination glwe_keyswitch_key memory + for (input_key_polynomial, mut keyswitch_key_block) in input_glwe_sk + .as_polynomial_list() + .iter() + .zip(glwe_keyswitch_key.iter_mut()) + { + // The plaintexts used to encrypt a key element will be stored in this buffer + let mut decomposition_polynomials_buffer = PolynomialList::new( + Scalar::ZERO, + input_glwe_sk.polynomial_size(), + PolynomialCount(decomp_level_count.0), + ); + + // We fill the buffer with the powers of the key elmements + for (level, mut message_polynomial) in (1..=decomp_level_count.0) + .rev() + .map(DecompositionLevel) + .zip(decomposition_polynomials_buffer.as_mut_view().iter_mut()) + { + let term = DecompositionTermSliceNonNative::new( + level, + decomp_base_log, + input_key_polynomial.as_ref(), + ciphertext_modulus, + ); + term.to_approximate_recomposition_summand(message_polynomial.as_mut()); + } + + let decomposition_plaintexts_buffer = + PlaintextList::from_container(decomposition_polynomials_buffer.into_container()); + + encrypt_glwe_ciphertext_list( + output_glwe_sk, + &mut keyswitch_key_block, + &decomposition_plaintexts_buffer, + noise_distribution, + generator, + ); + } +} + +/// Allocate a new [`GLWE keyswitch key`](`GlweKeyswitchKey`) and fill it with an actual +/// keyswitching key constructed from an input and an output +/// [`GLWE secret key`](`GlweSecretKey`). +/// +/// See [`keyswitch_glwe_ciphertext`] for usage. +pub fn allocate_and_generate_new_glwe_keyswitch_key< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, + generator: &mut EncryptionRandomGenerator, +) -> GlweKeyswitchKeyOwned +where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + Gen: ByteRandomGenerator, +{ + let mut new_glwe_keyswitch_key = GlweKeyswitchKeyOwned::new( + Scalar::ZERO, + decomp_base_log, + decomp_level_count, + input_glwe_sk.glwe_dimension(), + output_glwe_sk.glwe_dimension(), + output_glwe_sk.polynomial_size(), + ciphertext_modulus, + ); + + generate_glwe_keyswitch_key( + input_glwe_sk, + output_glwe_sk, + &mut new_glwe_keyswitch_key, + noise_distribution, + generator, + ); + + new_glwe_keyswitch_key +} diff --git a/tfhe/src/core_crypto/algorithms/glwe_relinearization_key_generation.rs b/tfhe/src/core_crypto/algorithms/glwe_relinearization_key_generation.rs new file mode 100644 index 0000000000..ab67482562 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_relinearization_key_generation.rs @@ -0,0 +1,164 @@ +//! Module containing primitives pertaining to [`GLWE relinearization key +//! generation`](`GlweRelinearizationKey`). + +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::algorithms::*; +use crate::core_crypto::commons::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::math::random::{Distribution, Uniform}; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, PolynomialCount}; + +/// Fill a [`GLWE Relinearization key`](`GlweRelinearizationKey`) +/// with an actual key. +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweCiphertext creation +/// let glwe_size = GlweSize(3); +/// let polynomial_size = PolynomialSize(1024); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(7); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the GlweSecretKey +/// let glwe_secret_key: GlweSecretKey> = allocate_and_generate_new_binary_glwe_secret_key( +/// glwe_size.to_glwe_dimension(), +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let relin_key = allocate_and_generate_glwe_relinearization_key( +/// &glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// assert!(!relin_key.as_ref().iter().all(|&x| x == 0)); +/// ``` +pub fn generate_glwe_relinearization_key< + Scalar, + NoiseDistribution, + GlweKeyCont, + RelinKeyCont, + Gen, +>( + glwe_secret_key: &GlweSecretKey, + glwe_relinearization_key: &mut GlweRelinearizationKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + GlweKeyCont: Container, + RelinKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert_eq!( + glwe_secret_key.glwe_dimension(), + glwe_relinearization_key.glwe_dimension() + ); + assert_eq!( + glwe_secret_key.polynomial_size(), + glwe_relinearization_key.polynomial_size() + ); + + // We retrieve decomposition arguments + let glwe_dimension = glwe_relinearization_key.glwe_dimension(); + let decomp_level_count = glwe_relinearization_key.decomposition_level_count(); + let decomp_base_log = glwe_relinearization_key.decomposition_base_log(); + let polynomial_size = glwe_relinearization_key.polynomial_size(); + let ciphertext_modulus = glwe_relinearization_key.ciphertext_modulus(); + + // Construct the "glwe secret key" we want to keyswitch from, this is made up of the square + // and cross terms appearing when tensoring the glwe_secret_key with itself + let mut input_sk_poly_list = PolynomialList::new( + Scalar::ZERO, + polynomial_size, + PolynomialCount(glwe_dimension.0 * (glwe_dimension.0 + 1) / 2), + ); + let mut input_sk_poly_list_iter = input_sk_poly_list.iter_mut(); + + // We compute the polynomial multiplication in the same way, + // regardless of the ciphertext modulus. + for i in 0..glwe_dimension.0 { + for j in 0..i + 1 { + let mut input_key_pol = input_sk_poly_list_iter.next().unwrap(); + polynomial_wrapping_sub_mul_assign( + &mut input_key_pol, + &glwe_secret_key.as_polynomial_list().get(i), + &glwe_secret_key.as_polynomial_list().get(j), + ); + } + } + + let input_glwe_sk = GlweSecretKey::from_container(input_sk_poly_list.as_ref(), polynomial_size); + + let mut glwe_ks_key = GlweKeyswitchKey::from_container( + glwe_relinearization_key.as_mut(), + decomp_base_log, + decomp_level_count, + glwe_dimension.to_glwe_size(), + polynomial_size, + ciphertext_modulus, + ); + + generate_glwe_keyswitch_key( + &input_glwe_sk, + glwe_secret_key, + &mut glwe_ks_key, + noise_distribution, + generator, + ); +} + +pub fn allocate_and_generate_glwe_relinearization_key( + glwe_secret_key: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, + generator: &mut EncryptionRandomGenerator, +) -> GlweRelinearizationKeyOwned +where + Scalar: Encryptable, + NoiseDistribution: Distribution, + KeyCont: Container, + Gen: ByteRandomGenerator, +{ + let mut glwe_relinearization_key = GlweRelinearizationKeyOwned::new( + Scalar::ZERO, + decomp_base_log, + decomp_level_count, + glwe_secret_key.glwe_dimension().to_glwe_size(), + glwe_secret_key.polynomial_size(), + ciphertext_modulus, + ); + generate_glwe_relinearization_key( + glwe_secret_key, + &mut glwe_relinearization_key, + noise_distribution, + generator, + ); + + glwe_relinearization_key +} diff --git a/tfhe/src/core_crypto/algorithms/glwe_tensor_product.rs b/tfhe/src/core_crypto/algorithms/glwe_tensor_product.rs new file mode 100644 index 0000000000..b5038feee6 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_tensor_product.rs @@ -0,0 +1,1981 @@ +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::algorithms::slice_algorithms::*; +use crate::core_crypto::commons::math::decomposition::{ + SignedDecomposer, SignedDecomposerNonNative, +}; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::prelude::*; + +/// Converts a Scalar into a u128 via its signed value. +/// This is needed for the tensor product operation. +fn convert_scalar_to_u128(x: Scalar) -> u128 +where + Scalar: UnsignedInteger, +{ + let y = x.into_signed(); + if y < Scalar::Signed::ZERO { + let neg_x = x.wrapping_neg(); + let neg_x_u128 = >::cast_into(neg_x); + neg_x_u128.wrapping_neg() + } else { + >::cast_into(x) + } +} + +/// Converts a polynomial via the above function +fn convert_polynomial(poly: &Polynomial) -> Polynomial> +where + Scalar: UnsignedInteger, + InputCont: Container, +{ + Polynomial::from_container( + poly.as_ref() + .iter() + .map(|&x| convert_scalar_to_u128(x)) + .collect::>(), + ) +} + +/// Converts a polynomial list via the above function +fn convert_polynomial_list( + poly_list: &PolynomialList, +) -> PolynomialList> +where + Scalar: UnsignedInteger, + InputCont: Container, +{ + PolynomialList::from_container( + poly_list + .as_ref() + .iter() + .map(|&x| convert_scalar_to_u128(x)) + .collect::>(), + poly_list.polynomial_size(), + ) +} + +/// Converts a Scalar into a u128 via its signed value +/// for a custom modulus. +/// This is needed for the tensor product operation. +fn convert_scalar_to_u128_custom_mod( + x: Scalar, + ciphertext_modulus: CiphertextModulus, + square_ct_mod: u128, +) -> u128 +where + Scalar: UnsignedInteger, +{ + let custom_modulus = ciphertext_modulus + .get_custom_modulus_as_optional_scalar() + .unwrap(); + let half_ct_mod = custom_modulus / Scalar::TWO; + if x > half_ct_mod { + let neg_x = x.wrapping_neg_custom_mod(custom_modulus); + let neg_x_u128 = >::cast_into(neg_x); + neg_x_u128.wrapping_neg_custom_mod(square_ct_mod) + } else { + >::cast_into(x) + } +} + +/// Converts a polynomial via the above function +fn convert_polynomial_custom_mod( + poly: &Polynomial, + ciphertext_modulus: CiphertextModulus, + square_ct_mod: u128, +) -> Polynomial> +where + Scalar: UnsignedInteger, + InputCont: Container, +{ + Polynomial::from_container( + poly.as_ref() + .iter() + .map(|&x| convert_scalar_to_u128_custom_mod(x, ciphertext_modulus, square_ct_mod)) + .collect::>(), + ) +} + +/// Converts a polynomial list via the above function +fn convert_polynomial_list_custom_mod( + poly_list: &PolynomialList, + ciphertext_modulus: CiphertextModulus, + square_ct_mod: u128, +) -> PolynomialList> +where + Scalar: UnsignedInteger, + InputCont: Container, +{ + PolynomialList::from_container( + poly_list + .as_ref() + .iter() + .map(|&x| convert_scalar_to_u128_custom_mod(x, ciphertext_modulus, square_ct_mod)) + .collect::>(), + poly_list.polynomial_size(), + ) +} + +/// Converts a u128 to a Scalar by dividing by scale while keeping its sign. +/// This is needed for the tensor product operation. +fn scale_down_u128_to_scalar(x: u128, scale: Scalar) -> Scalar +where + Scalar: UnsignedInteger, +{ + let y = x as i128; + let scale_u128 = >::cast_into(scale); + if y < 0i128 { + let neg_x = x.wrapping_neg(); + let neg_x_scaled = Scalar::cast_from((neg_x + (scale_u128 / 2)) / scale_u128); + neg_x_scaled.wrapping_neg() + } else { + Scalar::cast_from((x + (scale_u128 / 2)) / scale_u128) + } +} + +/// Apply the above function to a polynomial component-wise. +fn scale_down_polynomial( + input_poly: &Polynomial, + output_poly: &mut Polynomial, + scale: Scalar, +) where + Scalar: UnsignedInteger, + InputCont: Container, + OutputCont: ContainerMut, +{ + output_poly + .as_mut() + .iter_mut() + .zip(input_poly.as_ref().iter()) + .for_each(|(dst, &src)| *dst = scale_down_u128_to_scalar(src, scale)); +} + +/// Converts a u128 to a Scalar by dividing by scale while keeping its sign +/// for a custom modulus. +/// This is needed for the tensor product operation. +fn scale_down_u128_to_scalar_custom_mod( + x: u128, + scale: Scalar, + ciphertext_modulus: CiphertextModulus, + square_ct_mod: u128, +) -> Scalar +where + Scalar: UnsignedInteger, +{ + let custom_modulus = ciphertext_modulus + .get_custom_modulus_as_optional_scalar() + .unwrap(); + let custom_mod_u128 = >::cast_into(custom_modulus); + let half_square_ct_mod = square_ct_mod / 2u128; + let scale_u128 = >::cast_into(scale); + if x > half_square_ct_mod { + let neg_x = x.wrapping_neg_custom_mod(square_ct_mod); + let neg_x_scaled = (neg_x + (scale_u128 / 2)) / scale_u128; + let neg_x_scaled_and_reduced = neg_x_scaled.wrapping_rem(custom_mod_u128); + Scalar::cast_from(neg_x_scaled_and_reduced).wrapping_neg_custom_mod(custom_modulus) + } else { + let x_scaled = (x + (scale_u128 / 2)) / scale_u128; + let x_scaled_and_reduced = x_scaled.wrapping_rem(custom_mod_u128); + Scalar::cast_from(x_scaled_and_reduced) + } +} + +/// Apply the above function to a polynomial component-wise. +fn scale_down_polynomial_custom_mod( + input_poly: &Polynomial, + output_poly: &mut Polynomial, + scale: Scalar, + ciphertext_modulus: CiphertextModulus, + square_ct_mod: u128, +) where + Scalar: UnsignedInteger, + InputCont: Container, + OutputCont: ContainerMut, +{ + output_poly + .as_mut() + .iter_mut() + .zip(input_poly.as_ref().iter()) + .for_each(|(dst, &src)| { + *dst = + scale_down_u128_to_scalar_custom_mod(src, scale, ciphertext_modulus, square_ct_mod) + }); +} + +/// Attempts to compute a product between u128s with a custom modulus that is also a u128. +/// This works only because the inputs are "small". +/// Further, the assumption is that the inputs are signed. +/// This product is used in the tensor product operation. +fn wrapping_mul_custom_mod_u128(lhs: u128, rhs: u128, custom_modulus: u128) -> u128 { + let half_custom_mod = custom_modulus / 2; + let lhs_neg = lhs > half_custom_mod; + let abs_lhs = if lhs_neg { + lhs.wrapping_neg_custom_mod(custom_modulus) + } else { + lhs + }; + let rhs_neg = rhs > half_custom_mod; + let abs_rhs = if rhs_neg { + rhs.wrapping_neg_custom_mod(custom_modulus) + } else { + rhs + }; + let prod_neg = lhs_neg ^ rhs_neg; + let (mut abs_prod, err) = abs_lhs.overflowing_mul(abs_rhs); + if err { + // The assumption here is that abs_lhs * abs_rhs are not much larger than 2**64. + // Thus the product is (a + b*2^64)*(c + d*2^64) = ac + (bc + ad)*2^64 + bd*2^128 + // where b and d and thus bd is very small. + // Further set bc + ad = e + f*2^64 where f should be small if b and d are. + // Then the product is ac + e*2^64 + (bd + f)*2^128 + // with (bd + f) small + // Write 2^128 = r modulo the custom modulus so that the product is + // ac + e*2^64 + (bd + f)*r for a small signed r + // This can be computed without wrapping around modulo 2^128 if + // (bd + f)*|r| < 2^128 otherwise there is an error + // There should be no error if the modulus is not close to 2^128 or is equal to + // 2^128 - r with r not close to 2^128 + let b = abs_lhs >> 64; + let d = abs_rhs >> 64; + let a = abs_lhs - (b << 64); + let c = abs_rhs - (d << 64); + let ac = a.wrapping_mul(c); // this is the first term + let ad = a.wrapping_mul(d); + let bc = b.wrapping_mul(c); + let (bcpad, err) = bc.overflowing_add(ad); + assert!(!err, "multiplication of custom u128s failed: {lhs}, {rhs}",); + let f = bcpad >> 64; + let e = bcpad - (f << 64); + let middle_term = (e << 64).wrapping_rem(custom_modulus); + let r = u128::MAX.wrapping_rem(custom_modulus).wrapping_add(1u128); + let neg_r = r > half_custom_mod; + let abs_r = if neg_r { r.wrapping_neg() } else { r }; + let bd = b.wrapping_mul(d); + let (bdpf, err) = bd.overflowing_add(f); + assert!(!err, "multiplication of custom u128s failed: {lhs}, {rhs}",); + let (bdpfr, err) = bdpf.overflowing_mul(abs_r); + assert!(!err, "multiplication of custom u128s failed: {lhs}, {rhs}",); + let last_term = bdpfr.wrapping_rem(custom_modulus); + if neg_r { + abs_prod = ac + .wrapping_add_custom_mod(middle_term, custom_modulus) + .wrapping_sub_custom_mod(last_term, custom_modulus); + } else { + abs_prod = ac + .wrapping_add_custom_mod(middle_term, custom_modulus) + .wrapping_add_custom_mod(last_term, custom_modulus); + } + } else { + abs_prod = abs_prod.wrapping_rem(custom_modulus); + } + let mut prod = abs_prod; + if prod_neg { + prod = prod.wrapping_neg_custom_mod(custom_modulus); + } + prod +} + +// What follows is a copy of the polynomial multiplication algorithms updated to use the +// above product specialised to a custom u128 modulus. + +fn polynomial_karatsuba_wrapping_mul_custom_mod_u128( + output: &mut Polynomial, + p: &Polynomial, + q: &Polynomial, + custom_modulus: u128, +) where + OutputCont: ContainerMut, + LhsCont: Container, + RhsCont: Container, +{ + // check same dimensions + assert!( + output.polynomial_size() == p.polynomial_size(), + "Output polynomial size {:?} is not the same as input lhs polynomial {:?}.", + output.polynomial_size(), + p.polynomial_size(), + ); + assert!( + output.polynomial_size() == q.polynomial_size(), + "Output polynomial size {:?} is not the same as input rhs polynomial {:?}.", + output.polynomial_size(), + q.polynomial_size(), + ); + + let poly_size = output.polynomial_size().0; + + // check dimensions are a power of 2 + assert!(poly_size.is_power_of_two()); + + // allocate slices for the rec + let mut a0 = vec![0u128; poly_size]; + let mut a1 = vec![0u128; poly_size]; + let mut a2 = vec![0u128; poly_size]; + let mut input_a2_p = vec![0u128; poly_size / 2]; + let mut input_a2_q = vec![0u128; poly_size / 2]; + + // prepare for splitting + let bottom = 0..(poly_size / 2); + let top = (poly_size / 2)..poly_size; + + // induction + induction_karatsuba_custom_mod_u128( + &mut a0, + &p[bottom.clone()], + &q[bottom.clone()], + custom_modulus, + ); + induction_karatsuba_custom_mod_u128(&mut a1, &p[top.clone()], &q[top.clone()], custom_modulus); + slice_wrapping_add_custom_mod( + &mut input_a2_p, + &p[bottom.clone()], + &p[top.clone()], + custom_modulus, + ); + slice_wrapping_add_custom_mod( + &mut input_a2_q, + &q[bottom.clone()], + &q[top.clone()], + custom_modulus, + ); + induction_karatsuba_custom_mod_u128(&mut a2, &input_a2_p, &input_a2_q, custom_modulus); + + // rebuild the result + let output: &mut [u128] = output.as_mut(); + slice_wrapping_sub_custom_mod(output, &a0, &a1, custom_modulus); + slice_wrapping_sub_assign_custom_mod( + &mut output[bottom.clone()], + &a2[top.clone()], + custom_modulus, + ); + slice_wrapping_add_assign_custom_mod( + &mut output[bottom.clone()], + &a0[top.clone()], + custom_modulus, + ); + slice_wrapping_add_assign_custom_mod( + &mut output[bottom.clone()], + &a1[top.clone()], + custom_modulus, + ); + slice_wrapping_add_assign_custom_mod( + &mut output[top.clone()], + &a2[bottom.clone()], + custom_modulus, + ); + slice_wrapping_sub_assign_custom_mod( + &mut output[top.clone()], + &a0[bottom.clone()], + custom_modulus, + ); + slice_wrapping_sub_assign_custom_mod(&mut output[top], &a1[bottom], custom_modulus); +} + +const KARATUSBA_STOP: usize = 64; +fn induction_karatsuba_custom_mod_u128( + res: &mut [u128], + p: &[u128], + q: &[u128], + custom_modulus: u128, +) { + // stop the induction when polynomials have KARATUSBA_STOP elements + if p.len() <= KARATUSBA_STOP { + // schoolbook algorithm + for (lhs_degree, &lhs_elt) in p.iter().enumerate() { + let res = &mut res[lhs_degree..]; + for (&rhs_elt, res) in q.iter().zip(res) { + *res = (*res).wrapping_add_custom_mod( + wrapping_mul_custom_mod_u128(lhs_elt, rhs_elt, custom_modulus), + custom_modulus, + ) + } + } + } else { + let poly_size = res.len(); + + // allocate slices for the rec + let mut a0 = vec![0u128; poly_size / 2]; + let mut a1 = vec![0u128; poly_size / 2]; + let mut a2 = vec![0u128; poly_size / 2]; + let mut input_a2_p = vec![0u128; poly_size / 4]; + let mut input_a2_q = vec![0u128; poly_size / 4]; + + // prepare for splitting + let bottom = 0..(poly_size / 4); + let top = (poly_size / 4)..(poly_size / 2); + + // rec + induction_karatsuba_custom_mod_u128( + &mut a0, + &p[bottom.clone()], + &q[bottom.clone()], + custom_modulus, + ); + induction_karatsuba_custom_mod_u128( + &mut a1, + &p[top.clone()], + &q[top.clone()], + custom_modulus, + ); + slice_wrapping_add_custom_mod( + &mut input_a2_p, + &p[bottom.clone()], + &p[top.clone()], + custom_modulus, + ); + slice_wrapping_add_custom_mod(&mut input_a2_q, &q[bottom], &q[top], custom_modulus); + induction_karatsuba_custom_mod_u128(&mut a2, &input_a2_p, &input_a2_q, custom_modulus); + + // rebuild the result + slice_wrapping_sub_custom_mod( + &mut res[(poly_size / 4)..(3 * poly_size / 4)], + &a2, + &a0, + custom_modulus, + ); + slice_wrapping_sub_assign_custom_mod( + &mut res[(poly_size / 4)..(3 * poly_size / 4)], + &a1, + custom_modulus, + ); + slice_wrapping_add_assign_custom_mod(&mut res[0..(poly_size / 2)], &a0, custom_modulus); + slice_wrapping_add_assign_custom_mod( + &mut res[(poly_size / 2)..poly_size], + &a1, + custom_modulus, + ); + } +} + +fn polynomial_wrapping_add_mul_schoolbook_assign_custom_mod_u128< + OutputCont, + InputCont1, + InputCont2, +>( + output: &mut Polynomial, + lhs: &Polynomial, + rhs: &Polynomial, + custom_modulus: u128, +) where + OutputCont: ContainerMut, + InputCont1: Container, + InputCont2: Container, +{ + fn implementation( + mut output: Polynomial<&mut [u128]>, + lhs: Polynomial<&[u128]>, + rhs: Polynomial<&[u128]>, + custom_modulus: u128, + ) { + let polynomial_size = output.polynomial_size(); + let degree = output.degree(); + for (lhs_degree, &lhs_coeff) in lhs.iter().enumerate() { + for (rhs_degree, &rhs_coeff) in rhs.iter().enumerate() { + let target_degree = lhs_degree + rhs_degree; + if target_degree <= degree { + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = (*output_coefficient).wrapping_add_custom_mod( + wrapping_mul_custom_mod_u128(lhs_coeff, rhs_coeff, custom_modulus), + custom_modulus, + ); + } else { + let target_degree = target_degree % polynomial_size.0; + let output_coefficient = &mut output.as_mut()[target_degree]; + + *output_coefficient = (*output_coefficient).wrapping_sub_custom_mod( + wrapping_mul_custom_mod_u128(lhs_coeff, rhs_coeff, custom_modulus), + custom_modulus, + ); + } + } + } + } + implementation( + output.as_mut_view(), + lhs.as_view(), + rhs.as_view(), + custom_modulus, + ); +} + +fn polynomial_wrapping_add_mul_assign_custom_mod_u128( + output: &mut Polynomial, + lhs: &Polynomial, + rhs: &Polynomial, + custom_modulus: u128, +) where + OutputCont: ContainerMut, + InputCont1: Container, + InputCont2: Container, +{ + assert!( + output.polynomial_size() == lhs.polynomial_size(), + "Output polynomial size {:?} is not the same as input lhs polynomial {:?}.", + output.polynomial_size(), + lhs.polynomial_size(), + ); + assert!( + output.polynomial_size() == rhs.polynomial_size(), + "Output polynomial size {:?} is not the same as input rhs polynomial {:?}.", + output.polynomial_size(), + rhs.polynomial_size(), + ); + + let polynomial_size = output.polynomial_size(); + + if polynomial_size.0.is_power_of_two() && polynomial_size.0 > KARATUSBA_STOP { + let mut tmp = Polynomial::new(0u128, polynomial_size); + + polynomial_karatsuba_wrapping_mul_custom_mod_u128(&mut tmp, lhs, rhs, custom_modulus); + polynomial_wrapping_add_assign_custom_mod(output, &tmp, custom_modulus); + } else { + polynomial_wrapping_add_mul_schoolbook_assign_custom_mod_u128( + output, + lhs, + rhs, + custom_modulus, + ) + } +} + +/// Compute the tensor product of the left-hand side [`GLWE ciphertext`](`GlweCiphertext`) with the +/// right-hand side [`GLWE ciphertext`](`GlweCiphertext`) +/// writing the result in the output [`GlweCiphertext>`](`GlweCiphertext>`). +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*; +/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweCiphertext creation +/// let glwe_size = GlweSize(3); +/// let polynomial_size = PolynomialSize(256); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let decomp_base_log = DecompositionBaseLog(21); +/// let decomp_level_count = DecompositionLevelCount(2); +/// let ciphertext_modulus = CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(); +/// +/// let delta1 = ciphertext_modulus.get_custom_modulus() as u64 / (1 << 5); +/// let delta2 = ciphertext_modulus.get_custom_modulus() as u64 / (1 << 4); +/// let delta = std::cmp::min(delta1, delta2); +/// let output_delta = std::cmp::max(delta1, delta2); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// let decomposer = SignedDecomposerNonNative::new( +/// DecompositionBaseLog(4), +/// DecompositionLevelCount(1), +/// ciphertext_modulus, +/// ); +/// +/// // Create the GlweSecretKey +/// let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// glwe_size.to_glwe_dimension(), +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// // Create the first plaintext, we encrypt a single integer rather than a general polynomial +/// let msg_1 = 3u64; +/// let encoded_msg_1 = msg_1 * delta1; +/// +/// let mut plaintext_list_1 = PlaintextList::new(0u64, PlaintextCount(polynomial_size.0)); +/// plaintext_list_1.as_mut()[0] = encoded_msg_1; +/// +/// // Create the first GlweCiphertext +/// let mut glwe_1 = GlweCiphertext::new(0u64, glwe_size, polynomial_size, ciphertext_modulus); +/// +/// encrypt_glwe_ciphertext( +/// &glwe_secret_key, +/// &mut glwe_1, +/// &plaintext_list_1, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// // Create the second plaintext +/// let msg_2 = 2u64; +/// let encoded_msg_2 = msg_2 * delta2; +/// +/// let mut plaintext_list_2 = PlaintextList::new(0u64, PlaintextCount(polynomial_size.0)); +/// plaintext_list_2.as_mut()[0] = encoded_msg_2; +/// +/// // Create the second GlweCiphertext +/// let mut glwe_2 = GlweCiphertext::new(0u64, glwe_size, polynomial_size, ciphertext_modulus); +/// +/// encrypt_glwe_ciphertext( +/// &glwe_secret_key, +/// &mut glwe_2, +/// &plaintext_list_2, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// // Create the output GlweCiphertext +/// let tensor_glwe_dim = GlweDimension((glwe_size.0 - 1) * (glwe_size.0 + 2) / 2); +/// let mut tensor_output = GlweCiphertext::new( +/// 0u64, +/// tensor_glwe_dim.to_glwe_size(), +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// // Perform the tensor product +/// glwe_tensor_product(&glwe_1, &glwe_2, &mut tensor_output, delta); +/// +/// // Compute the tensor product key +/// let mut tensor_key_poly_list = +/// PolynomialList::new(0u64, polynomial_size, PolynomialCount(tensor_glwe_dim.0)); +/// let mut key_iter = tensor_key_poly_list.iter_mut(); +/// +/// for i in 0..glwe_size.0 - 1 { +/// for j in 0..i + 1 { +/// let mut key_pol = key_iter.next().unwrap(); +/// polynomial_wrapping_sub_mul_assign_custom_mod( +/// &mut key_pol, +/// &glwe_secret_key.as_polynomial_list().get(i), +/// &glwe_secret_key.as_polynomial_list().get(j), +/// ciphertext_modulus.get_custom_modulus().cast_into(), +/// ); +/// } +/// let mut key_pol = key_iter.next().unwrap(); +/// polynomial_wrapping_add_assign_custom_mod( +/// &mut key_pol, +/// &glwe_secret_key.as_polynomial_list().get(i), +/// ciphertext_modulus.get_custom_modulus().cast_into(), +/// ); +/// } +/// +/// let tensor_key = GlweSecretKey::from_container(tensor_key_poly_list.as_ref(), polynomial_size); +/// +/// // Decrypt the tensor product ciphertext +/// let mut output_plaintext = PlaintextList::new(0u64, PlaintextCount(polynomial_size.0)); +/// +/// decrypt_glwe_ciphertext(&tensor_key, &tensor_output, &mut output_plaintext); +/// +/// // Get the raw vector +/// let mut cleartext = output_plaintext.into_container(); +/// // Remove the encoding +/// cleartext +/// .iter_mut() +/// .for_each(|elt| *elt = decomposer.decode_plaintext(*elt)); +/// // Get the list immutably +/// let cleartext = cleartext; +/// +/// // Compute what the product should be +/// let pt1 = Polynomial::from_container( +/// plaintext_list_1 +/// .into_container() +/// .iter() +/// .map(|&x| >::cast_into(x)) +/// .collect::>(), +/// ); +/// let pt2 = Polynomial::from_container( +/// plaintext_list_2 +/// .into_container() +/// .iter() +/// .map(|&x| >::cast_into(x)) +/// .collect::>(), +/// ); +/// +/// let mut product = Polynomial::new(0u128, polynomial_size); +/// polynomial_wrapping_mul(&mut product, &pt1, &pt2); +/// +/// let mut scaled_product = Polynomial::new(0u64, polynomial_size); +/// scaled_product +/// .as_mut() +/// .iter_mut() +/// .zip(product.as_ref().iter()) +/// .for_each(|(dest, &source)| { +/// *dest = +/// u64::cast_from(source / >::cast_into(delta)) / output_delta +/// }); +/// +/// // Check we recovered the correct message +/// cleartext +/// .iter() +/// .zip(scaled_product.iter()) +/// .for_each(|(&elt, coeff)| assert_eq!(elt, *coeff)); +/// +/// let glwe_relin_key = allocate_and_generate_glwe_relinearization_key( +/// &glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_glwe_ciphertext = +/// GlweCiphertext::new(0u64, glwe_size, polynomial_size, ciphertext_modulus); +/// +/// glwe_relinearization(&tensor_output, &glwe_relin_key, &mut output_glwe_ciphertext); +/// +/// // Decrypt the output glwe ciphertext +/// let mut output_plaintext = PlaintextList::new(0u64, PlaintextCount(polynomial_size.0)); +/// +/// decrypt_glwe_ciphertext( +/// &glwe_secret_key, +/// &output_glwe_ciphertext, +/// &mut output_plaintext, +/// ); +/// +/// // Get the raw vector +/// let mut cleartext = output_plaintext.into_container(); +/// // Remove the encoding +/// cleartext +/// .iter_mut() +/// .for_each(|elt| *elt = decomposer.decode_plaintext(*elt)); +/// // Get the list immutably +/// let cleartext = cleartext; +/// +/// // Check we recovered the correct message +/// cleartext +/// .iter() +/// .zip(scaled_product.iter()) +/// .for_each(|(&elt, coeff)| assert_eq!(elt, *coeff)); +/// ``` +/// based on algorithm 1 of `` +pub fn glwe_tensor_product( + input_glwe_ciphertext_lhs: &GlweCiphertext, + input_glwe_ciphertext_rhs: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, + scale: Scalar, +) where + Scalar: UnsignedInteger, + LhsCont: Container, + RhsCont: Container, + OutputCont: ContainerMut, +{ + assert!( + Scalar::BITS <= 64, + "The tensor product is not implemented for bit-widths larger than 64." + ); + + assert_eq!( + input_glwe_ciphertext_lhs.ciphertext_modulus(), + input_glwe_ciphertext_rhs.ciphertext_modulus(), + "Mismatched moduli between lhs ({:?}) and rhs ({:?}) GlweCiphertext", + input_glwe_ciphertext_lhs.ciphertext_modulus(), + input_glwe_ciphertext_rhs.ciphertext_modulus() + ); + + assert_eq!( + input_glwe_ciphertext_lhs.ciphertext_modulus(), + output_glwe_ciphertext.ciphertext_modulus(), + "Mismatched moduli between input ({:?}) and output ({:?}) GlweCiphertext", + input_glwe_ciphertext_lhs.ciphertext_modulus(), + output_glwe_ciphertext.ciphertext_modulus() + ); + + let ciphertext_modulus = input_glwe_ciphertext_lhs.ciphertext_modulus(); + + if ciphertext_modulus.is_compatible_with_native_modulus() { + glwe_tensor_product_native_mod_compatible( + input_glwe_ciphertext_lhs, + input_glwe_ciphertext_rhs, + output_glwe_ciphertext, + scale, + ) + } else { + glwe_tensor_product_other_mod( + input_glwe_ciphertext_lhs, + input_glwe_ciphertext_rhs, + output_glwe_ciphertext, + scale, + ) + } +} + +pub fn glwe_tensor_product_native_mod_compatible( + input_glwe_ciphertext_lhs: &GlweCiphertext, + input_glwe_ciphertext_rhs: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, + scale: Scalar, +) where + Scalar: UnsignedInteger, + LhsCont: Container, + RhsCont: Container, + OutputCont: ContainerMut, +{ + assert!( + input_glwe_ciphertext_lhs.polynomial_size().0 + == input_glwe_ciphertext_rhs.polynomial_size().0, + "The input glwe ciphertexts do not have the same polynomial size. The polynomial size of \ + the lhs is {}, while for the rhs it is {}.", + input_glwe_ciphertext_lhs.polynomial_size().0, + input_glwe_ciphertext_rhs.polynomial_size().0, + ); + + assert!( + input_glwe_ciphertext_lhs.polynomial_size().0 + == output_glwe_ciphertext.polynomial_size().0, + "The input glwe ciphertexts do not have the same polynomial size as the output glwe ciphertext. \ + The polynomial size of the inputs is {}, while for the output it is {}.", + input_glwe_ciphertext_lhs.polynomial_size().0, + output_glwe_ciphertext.polynomial_size().0, + ); + + assert!( + input_glwe_ciphertext_lhs.glwe_size().0 == input_glwe_ciphertext_rhs.glwe_size().0, + "The input glwe ciphertexts do not have the same glwe size. The glwe size of the lhs is \ + {}, while for the rhs it is {}.", + input_glwe_ciphertext_lhs.glwe_size().0, + input_glwe_ciphertext_rhs.glwe_size().0 + ); + + let k = input_glwe_ciphertext_lhs.glwe_size().to_glwe_dimension().0; + + // This is k + k*(k-1)/2 + k: k square terms, k*(k-1)/2 cross terms, k linear terms + let new_k = GlweDimension(k * (k + 3) / 2); + + assert!( + output_glwe_ciphertext.glwe_size().to_glwe_dimension().0 == new_k.0, + "The output glwe ciphertext does not have the correct glwe dimension. The dimension dictated by \ + the inputs is {}, while for the given output it is {}.", + new_k.0, + output_glwe_ciphertext.glwe_size().to_glwe_dimension().0, + ); + + let input_ciphertext_modulus = input_glwe_ciphertext_lhs.ciphertext_modulus(); + + assert!( + input_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" + ); + + let mut output_mask = output_glwe_ciphertext.get_mut_mask(); + let mut output_mask_poly_list = output_mask.as_mut_polynomial_list(); + let mut iter_output_mask = output_mask_poly_list.iter_mut(); + + let a_lhs = convert_polynomial_list(&input_glwe_ciphertext_lhs.get_mask().as_polynomial_list()); + let a_rhs = convert_polynomial_list(&input_glwe_ciphertext_rhs.get_mask().as_polynomial_list()); + + let b_lhs = convert_polynomial(&input_glwe_ciphertext_lhs.get_body().as_polynomial()); + let b_rhs = convert_polynomial(&input_glwe_ciphertext_rhs.get_body().as_polynomial()); + + for (i, a_lhs_i) in a_lhs.iter().enumerate() { + for (j, a_rhs_j) in a_rhs.iter().enumerate() { + if i == j { + //tensor elements corresponding to key -s_i^2 + let mut temp_poly_sq = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign(&mut temp_poly_sq, &a_lhs_i, &a_rhs_j); + + let mut output_poly_sq = iter_output_mask.next().unwrap(); + scale_down_polynomial(&temp_poly_sq, &mut output_poly_sq, scale); + + //tensor elements corresponding to key s_i + let mut temp_poly_s = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign(&mut temp_poly_s, &a_lhs_i, &b_rhs); + polynomial_wrapping_add_mul_assign(&mut temp_poly_s, &b_lhs, &a_rhs_j); + + let mut output_poly_s = iter_output_mask.next().unwrap(); + scale_down_polynomial(&temp_poly_s, &mut output_poly_s, scale); + } else { + //when i and j are different we only compute the terms where j < i + if j < i { + //tensor element corresponding to key -s_i*s_j + let mut temp_poly = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign(&mut temp_poly, &a_lhs_i, &a_rhs_j); + polynomial_wrapping_add_mul_assign( + &mut temp_poly, + &a_lhs.get(j), + &a_rhs.get(i), + ); + + let mut output_poly = iter_output_mask.next().unwrap(); + scale_down_polynomial(&temp_poly, &mut output_poly, scale); + } + } + } + } + + //tensor element corresponding to the body + let mut temp_poly_body = Polynomial::new(0u128, input_glwe_ciphertext_lhs.polynomial_size()); + polynomial_wrapping_add_mul_assign(&mut temp_poly_body, &b_lhs, &b_rhs); + let mut output_body = output_glwe_ciphertext.get_mut_body(); + let mut output_poly_body = output_body.as_mut_polynomial(); + scale_down_polynomial(&temp_poly_body, &mut output_poly_body, scale); +} + +pub fn glwe_tensor_product_other_mod( + input_glwe_ciphertext_lhs: &GlweCiphertext, + input_glwe_ciphertext_rhs: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, + scale: Scalar, +) where + Scalar: UnsignedInteger, + LhsCont: Container, + RhsCont: Container, + OutputCont: ContainerMut, +{ + assert!( + input_glwe_ciphertext_lhs.polynomial_size().0 + == input_glwe_ciphertext_rhs.polynomial_size().0, + "The input glwe ciphertexts do not have the same polynomial size. The polynomial size of \ + the lhs is {}, while for the rhs it is {}.", + input_glwe_ciphertext_lhs.polynomial_size().0, + input_glwe_ciphertext_rhs.polynomial_size().0 + ); + + assert!( + input_glwe_ciphertext_lhs.polynomial_size().0 + == output_glwe_ciphertext.polynomial_size().0, + "The input glwe ciphertexts do not have the same polynomial size as the output glwe ciphertext. \ + The polynomial size of the inputs is {}, while for the output it is {}.", + input_glwe_ciphertext_lhs.polynomial_size().0, + output_glwe_ciphertext.polynomial_size().0, + ); + + assert!( + input_glwe_ciphertext_lhs.glwe_size().0 == input_glwe_ciphertext_rhs.glwe_size().0, + "The input glwe ciphertexts do not have the same glwe size. The glwe size of the lhs is \ + {}, while for the rhs it is {}.", + input_glwe_ciphertext_lhs.glwe_size().0, + input_glwe_ciphertext_rhs.glwe_size().0 + ); + + let k = input_glwe_ciphertext_lhs.glwe_size().to_glwe_dimension().0; + + // This is k + k*(k-1)/2 + k: k square terms, k*(k-1)/2 cross terms, k linear terms + let new_k = GlweDimension(k * (k + 3) / 2); + + assert!( + output_glwe_ciphertext.glwe_size().to_glwe_dimension().0 == new_k.0, + "The output glwe ciphertext does not have the correct glwe dimension. The dimension dictated by \ + the inputs is {}, while for the given output it is {}.", + new_k.0, + output_glwe_ciphertext.glwe_size().to_glwe_dimension().0, + ); + + let ciphertext_modulus = input_glwe_ciphertext_lhs.ciphertext_modulus(); + let square_ct_mod = ciphertext_modulus.get_custom_modulus().pow(2); + + assert!( + !ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports non power of 2 moduli" + ); + + let mut output_mask = output_glwe_ciphertext.get_mut_mask(); + let mut output_mask_poly_list = output_mask.as_mut_polynomial_list(); + let mut iter_output_mask = output_mask_poly_list.iter_mut(); + + let a_lhs = convert_polynomial_list_custom_mod( + &input_glwe_ciphertext_lhs.get_mask().as_polynomial_list(), + ciphertext_modulus, + square_ct_mod, + ); + let a_rhs = convert_polynomial_list_custom_mod( + &input_glwe_ciphertext_rhs.get_mask().as_polynomial_list(), + ciphertext_modulus, + square_ct_mod, + ); + + let b_lhs = convert_polynomial_custom_mod( + &input_glwe_ciphertext_lhs.get_body().as_polynomial(), + ciphertext_modulus, + square_ct_mod, + ); + let b_rhs = convert_polynomial_custom_mod( + &input_glwe_ciphertext_rhs.get_body().as_polynomial(), + ciphertext_modulus, + square_ct_mod, + ); + + for (i, a_lhs_i) in a_lhs.iter().enumerate() { + for (j, a_rhs_j) in a_rhs.iter().enumerate() { + if i == j { + //tensor elements corresponding to key -s_i^2 + let mut temp_poly_sq = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly_sq, + &a_lhs_i, + &a_rhs_j, + square_ct_mod, + ); + + let mut output_poly_sq = iter_output_mask.next().unwrap(); + scale_down_polynomial_custom_mod( + &temp_poly_sq, + &mut output_poly_sq, + scale, + ciphertext_modulus, + square_ct_mod, + ); + + //tensor elements corresponding to key s_i + let mut temp_poly_s = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly_s, + &a_lhs_i, + &b_rhs, + square_ct_mod, + ); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly_s, + &b_lhs, + &a_rhs_j, + square_ct_mod, + ); + + let mut output_poly_s = iter_output_mask.next().unwrap(); + scale_down_polynomial_custom_mod( + &temp_poly_s, + &mut output_poly_s, + scale, + ciphertext_modulus, + square_ct_mod, + ); + } else { + //when i and j are different we only compute the terms where j < i + if j < i { + //tensor element corresponding to key -s_i*s_j + let mut temp_poly = Polynomial::new(0u128, a_lhs_i.polynomial_size()); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly, + &a_lhs_i, + &a_rhs_j, + square_ct_mod, + ); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly, + &a_lhs.get(j), + &a_rhs.get(i), + square_ct_mod, + ); + + let mut output_poly = iter_output_mask.next().unwrap(); + scale_down_polynomial_custom_mod( + &temp_poly, + &mut output_poly, + scale, + ciphertext_modulus, + square_ct_mod, + ); + } + } + } + } + + //tensor element corresponding to the body + let mut temp_poly_body = Polynomial::new(0u128, input_glwe_ciphertext_lhs.polynomial_size()); + polynomial_wrapping_add_mul_assign_custom_mod_u128( + &mut temp_poly_body, + &b_lhs, + &b_rhs, + square_ct_mod, + ); + let mut output_body = output_glwe_ciphertext.get_mut_body(); + let mut output_poly_body = output_body.as_mut_polynomial(); + scale_down_polynomial_custom_mod( + &temp_poly_body, + &mut output_poly_body, + scale, + ciphertext_modulus, + square_ct_mod, + ); +} + +/// Relinearize the [`GLWE ciphertext`](`GlweCiphertext`) that is output by the +/// glwe_tensor_product operation using a [`GLWE relinearization key`](`GlweRelinearizationKey`). +pub fn glwe_relinearization( + input_glwe_ciphertext: &GlweCiphertext, + relinearization_key: &GlweRelinearizationKey, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + InputCont: Container, + KeyCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + input_glwe_ciphertext.ciphertext_modulus(), + relinearization_key.ciphertext_modulus(), + ); + assert_eq!( + input_glwe_ciphertext.ciphertext_modulus(), + output_glwe_ciphertext.ciphertext_modulus(), + ); + + let ciphertext_modulus = input_glwe_ciphertext.ciphertext_modulus(); + + if ciphertext_modulus.is_compatible_with_native_modulus() { + glwe_relinearization_native_mod_compatible( + input_glwe_ciphertext, + relinearization_key, + output_glwe_ciphertext, + ) + } else { + glwe_relinearization_other_mod( + input_glwe_ciphertext, + relinearization_key, + output_glwe_ciphertext, + ) + } +} + +pub fn glwe_relinearization_native_mod_compatible( + input_glwe_ciphertext: &GlweCiphertext, + relinearization_key: &GlweRelinearizationKey, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + InputCont: Container, + KeyCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + relinearization_key.glwe_dimension().0 * (relinearization_key.glwe_dimension().0 + 3) / 2, + input_glwe_ciphertext.glwe_size().to_glwe_dimension().0 + ); + assert_eq!( + relinearization_key.glwe_size(), + output_glwe_ciphertext.glwe_size() + ); + assert_eq!( + relinearization_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size() + ); + assert_eq!( + relinearization_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size() + ); + assert!(relinearization_key + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new( + relinearization_key.decomposition_base_log(), + relinearization_key.decomposition_level_count(), + ); + + let mut relin_key_iter = relinearization_key.iter(); + let input_glwe_mask = input_glwe_ciphertext.get_mask(); + let input_glwe_mask_poly_list = input_glwe_mask.as_polynomial_list(); + let mut input_poly_iter = input_glwe_mask_poly_list.iter(); + + for i in 0..relinearization_key.glwe_size().0 - 1 { + for _ in 0..i + 1 { + let ksk = relin_key_iter.next().unwrap(); + let pol = input_poly_iter.next().unwrap(); + let mut decomposition_iter = decomposer.decompose_slice(pol.as_ref()); + // loop over the number of levels + for level_key_ciphertext in ksk.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + polynomial_list_wrapping_sub_scalar_mul_assign( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &Polynomial::from_container(decomposed.as_slice()), + ); + } + } + let pol = input_poly_iter.next().unwrap(); + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.as_mut_polynomial_list().get_mut(i), + &pol, + ) + } +} + +pub fn glwe_relinearization_other_mod( + input_glwe_ciphertext: &GlweCiphertext, + relinearization_key: &GlweRelinearizationKey, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + InputCont: Container, + KeyCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + relinearization_key.glwe_dimension().0 * (relinearization_key.glwe_dimension().0 + 3) / 2, + input_glwe_ciphertext.glwe_size().to_glwe_dimension().0 + ); + assert_eq!( + relinearization_key.glwe_size(), + output_glwe_ciphertext.glwe_size() + ); + assert_eq!( + relinearization_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size() + ); + assert_eq!( + relinearization_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size() + ); + + let ciphertext_modulus = input_glwe_ciphertext.ciphertext_modulus(); + + assert!( + !ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports non power of 2 moduli" + ); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext (no need to use non native addition here) + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposerNonNative::new( + relinearization_key.decomposition_base_log(), + relinearization_key.decomposition_level_count(), + ciphertext_modulus, + ); + + let mut relin_key_iter = relinearization_key.iter(); + let input_glwe_mask = input_glwe_ciphertext.get_mask(); + let input_glwe_mask_poly_list = input_glwe_mask.as_polynomial_list(); + let mut input_poly_iter = input_glwe_mask_poly_list.iter(); + let mut scalar_poly = Polynomial::new(Scalar::ZERO, input_glwe_ciphertext.polynomial_size()); + + for i in 0..relinearization_key.glwe_size().0 - 1 { + for _ in 0..i + 1 { + let ksk = relin_key_iter.next().unwrap(); + let pol = input_poly_iter.next().unwrap(); + let mut decomposition_iter = decomposer.decompose_slice(pol.as_ref()); + // loop over the number of levels + for level_key_ciphertext in ksk.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + decomposed.modular_value(scalar_poly.as_mut()); + polynomial_list_wrapping_sub_scalar_mul_assign_custom_mod( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &scalar_poly, + ciphertext_modulus.get_custom_modulus().cast_into(), + ); + } + } + let pol = input_poly_iter.next().unwrap(); + polynomial_wrapping_add_assign_custom_mod( + &mut output_glwe_ciphertext.as_mut_polynomial_list().get_mut(i), + &pol, + ciphertext_modulus.get_custom_modulus().cast_into(), + ) + } +} + +pub fn tensor_mult_with_relin( + input_glwe_ciphertext_lhs: &GlweCiphertext, + input_glwe_ciphertext_rhs: &GlweCiphertext, + scale: Scalar, + relinearization_key: &GlweRelinearizationKey, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + LhsCont: Container, + RhsCont: Container, + KeyCont: Container, + OutputCont: ContainerMut, +{ + let k = input_glwe_ciphertext_lhs.glwe_size().to_glwe_dimension().0; + let tensor_k = GlweDimension(k * (k + 3) / 2); + + let mut tensor_product_ciphertext = GlweCiphertextOwned::new( + Scalar::ZERO, + tensor_k.to_glwe_size(), + input_glwe_ciphertext_lhs.polynomial_size(), + input_glwe_ciphertext_lhs.ciphertext_modulus(), + ); + + glwe_tensor_product( + input_glwe_ciphertext_lhs, + input_glwe_ciphertext_rhs, + &mut tensor_product_ciphertext, + scale, + ); + glwe_relinearization( + &tensor_product_ciphertext, + relinearization_key, + output_glwe_ciphertext, + ); +} + +/// Compute the result of a dot product between two LWE lists +/// using the LWE Packing Keyswitch operation. +/// If we have two list of LWEs encrypting the values (v_i)_i and (w_i)_i +/// respectively this will compute an LWE ciphertext encrypting the dot product +/// of (v_i)_i and (w_i)_i, namely sum_i v_i*w_i +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let output_glwe_dimension = GlweDimension(1); +/// let output_polynomial_size = PolynomialSize(2048); +/// let decomp_base_log = DecompositionBaseLog(23); +/// let decomp_level_count = DecompositionLevelCount(1); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// output_glwe_dimension, +/// output_polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let pksk = allocate_and_generate_new_lwe_packing_keyswitch_key( +/// &input_lwe_secret_key, +/// &output_glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let number_of_lwes = 4; +/// +/// // Create the first LweCiphertextList +/// let mut input_lwe_list_lhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_lhs = +/// PlaintextList::from_container(vec![1u64 << 60, 0, 1 << 60, 1 << 60]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_lhs, +/// &input_plaintext_list_lhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// // Create the second LweCiphertextList +/// let mut input_lwe_list_rhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_rhs = +/// PlaintextList::from_container(vec![1u64 << 61, 1 << 60, 1 << 60, 0]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_rhs, +/// &input_plaintext_list_rhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let relin_key = allocate_and_generate_glwe_relinearization_key( +/// &output_glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Define the output lwe secret key +/// let output_lwe_secret_key = +/// LweSecretKey::from_container(output_glwe_secret_key.into_container()); +/// +/// // Create the output LweCiphertext +/// let mut output_lwe = LweCiphertext::new( +/// 0u64, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// ciphertext_modulus, +/// ); +/// +/// lwe_dot_product_via_packing_keyswitch( +/// &input_lwe_list_lhs, +/// &input_lwe_list_rhs, +/// &pksk, +/// &relin_key, +/// 1u64 << 60, +/// &mut output_lwe, +/// ); +/// +/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); +/// +/// let cleartext = rounded >> 60; +/// // result should be 1*2 + 0*1 + 1*1 + 1*0 = 3 +/// assert_eq!(cleartext, 3u64); +/// ``` +pub fn lwe_dot_product_via_packing_keyswitch( + input_lwe_ciphertext_list_1: &LweCiphertextList, + input_lwe_ciphertext_list_2: &LweCiphertextList, + lwe_pksk: &LwePackingKeyswitchKey, + relinearization_key: &GlweRelinearizationKey, + scale: Scalar, + output_lwe_ciphertext: &mut LweCiphertext, +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + input_lwe_ciphertext_list_1.lwe_ciphertext_count(), + input_lwe_ciphertext_list_2.lwe_ciphertext_count() + ); + + let mut packed_glwe_1 = GlweCiphertextOwned::new( + Scalar::ZERO, + lwe_pksk.output_glwe_size(), + lwe_pksk.output_polynomial_size(), + lwe_pksk.ciphertext_modulus(), + ); + keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext( + lwe_pksk, + input_lwe_ciphertext_list_1, + &mut packed_glwe_1, + ); + let mut packed_glwe_2 = GlweCiphertextOwned::new( + Scalar::ZERO, + lwe_pksk.output_glwe_size(), + lwe_pksk.output_polynomial_size(), + lwe_pksk.ciphertext_modulus(), + ); + keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext( + lwe_pksk, + input_lwe_ciphertext_list_2, + &mut packed_glwe_2, + ); + let mut relin_glwe_ciphertext = GlweCiphertextOwned::new( + Scalar::ZERO, + lwe_pksk.output_glwe_size(), + lwe_pksk.output_polynomial_size(), + lwe_pksk.ciphertext_modulus(), + ); + tensor_mult_with_relin( + &packed_glwe_1, + &packed_glwe_2, + scale, + relinearization_key, + &mut relin_glwe_ciphertext, + ); + + extract_lwe_sample_from_glwe_ciphertext( + &relin_glwe_ciphertext, + output_lwe_ciphertext, + MonomialDegree(input_lwe_ciphertext_list_1.lwe_ciphertext_count().0 - 1), + ); +} + +/// Compute the result of a dot product between two LWE lists +/// using the LWE Trace Packing Keyswitch operation, +/// If we have two lists of LWEs encrypting the values (v_i)_i and (w_i)_i +/// respectively, this will compute an LWE ciphertext encrypting the dot product +/// of (v_i)_i and (w_i)_i, namely sum_i v_i*w_i +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let output_glwe_dimension = GlweDimension(1); +/// let output_polynomial_size = PolynomialSize(2048); +/// let decomp_base_log = DecompositionBaseLog(23); +/// let decomp_level_count = DecompositionLevelCount(1); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// +/// let mut glwe_secret_key = +/// GlweSecretKey::new_empty_key(0u64, output_glwe_dimension, output_polynomial_size); +/// +/// generate_tpksk_output_glwe_secret_key( +/// &input_lwe_secret_key, +/// &mut glwe_secret_key, +/// ciphertext_modulus, +/// &mut secret_generator, +/// ); +/// +/// let mut lwe_tpksk = LweTracePackingKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// input_lwe_dimension.to_lwe_size(), +/// output_glwe_dimension.to_glwe_size(), +/// output_polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_lwe_trace_packing_keyswitch_key( +/// &glwe_secret_key, +/// &mut lwe_tpksk, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let number_of_lwes = 4; +/// +/// // Create the first LweCiphertextList +/// let mut input_lwe_list_lhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_lhs = +/// PlaintextList::from_container(vec![1u64 << 60, 0, 1 << 60, 1 << 60]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_lhs, +/// &input_plaintext_list_lhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// // Create the second LweCiphertextList +/// let mut input_lwe_list_rhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_rhs = +/// PlaintextList::from_container(vec![1u64 << 61, 1 << 60, 1 << 60, 0]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_rhs, +/// &input_plaintext_list_rhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let glwe_relin_key = allocate_and_generate_glwe_relinearization_key( +/// &glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Define the output lwe secret key +/// let output_lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); +/// +/// // Create the output LweCiphertext +/// let mut output_lwe_ciphertext = LweCiphertext::new( +/// 0_u64, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// ciphertext_modulus, +/// ); +/// +/// lwe_dot_product_via_trace_packing_keyswitch( +/// &input_lwe_list_lhs, +/// &input_lwe_list_rhs, +/// &lwe_tpksk, +/// &glwe_relin_key, +/// 1u64 << 60, +/// &mut output_lwe_ciphertext, +/// ); +/// +/// let output_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe_ciphertext); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// let rounded = decomposer.closest_representable(output_plaintext.0); +/// +/// let cleartext = rounded >> 60; +/// // result should be 1*2 + 0*1 + 1*1 + 1*0 = 3 +/// assert_eq!(cleartext, 3u64); +/// ``` +pub fn lwe_dot_product_via_trace_packing_keyswitch( + input_lwe_ciphertext_list_1: &LweCiphertextList, + input_lwe_ciphertext_list_2: &LweCiphertextList, + lwe_tpksk: &LweTracePackingKeyswitchKey, + relinearization_key: &GlweRelinearizationKey, + scale: Scalar, + output_lwe_ciphertext: &mut LweCiphertext, +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + input_lwe_ciphertext_list_1.lwe_ciphertext_count(), + input_lwe_ciphertext_list_2.lwe_ciphertext_count() + ); + + let mut packed_glwe_1 = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + let mut indices_1 = vec![0_usize; input_lwe_ciphertext_list_1.lwe_ciphertext_count().0]; + indices_1 + .iter_mut() + .enumerate() + .for_each(|(index, val)| *val = index); + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext( + lwe_tpksk, + &mut packed_glwe_1, + input_lwe_ciphertext_list_1, + &indices_1, + ); + let mut packed_glwe_2 = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + let mut indices_2 = vec![0_usize; input_lwe_ciphertext_list_2.lwe_ciphertext_count().0]; + indices_2 + .iter_mut() + .rev() + .enumerate() + .for_each(|(index, val)| *val = index); + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext( + lwe_tpksk, + &mut packed_glwe_2, + input_lwe_ciphertext_list_2, + &indices_2, + ); + let mut relin_glwe_ciphertext = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + tensor_mult_with_relin( + &packed_glwe_1, + &packed_glwe_2, + scale, + relinearization_key, + &mut relin_glwe_ciphertext, + ); + + extract_lwe_sample_from_glwe_ciphertext( + &relin_glwe_ciphertext, + output_lwe_ciphertext, + MonomialDegree(input_lwe_ciphertext_list_1.lwe_ciphertext_count().0 - 1), + ); +} + +/// Compute the result of a component-wise product of LWE lists +/// using the LWE Trace Packing Keyswitch operation. +/// If we have two lists of LWEs encrypting the values (v_i)_i and (w_i)_i +/// respectively, this will compute a list of LWE ciphertext encrypting the +/// products v_i*w_i +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let output_glwe_dimension = GlweDimension(1); +/// let output_polynomial_size = PolynomialSize(1024); +/// let decomp_base_log = DecompositionBaseLog(23); +/// let decomp_level_count = DecompositionLevelCount(1); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// +/// let mut glwe_secret_key = +/// GlweSecretKey::new_empty_key(0u64, output_glwe_dimension, output_polynomial_size); +/// +/// generate_tpksk_output_glwe_secret_key( +/// &input_lwe_secret_key, +/// &mut glwe_secret_key, +/// ciphertext_modulus, +/// &mut secret_generator, +/// ); +/// +/// let mut lwe_tpksk = LweTracePackingKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// input_lwe_dimension.to_lwe_size(), +/// output_glwe_dimension.to_glwe_size(), +/// output_polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_lwe_trace_packing_keyswitch_key( +/// &glwe_secret_key, +/// &mut lwe_tpksk, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let number_of_lwes = 4; +/// +/// // Create the first LweCiphertextList +/// let mut input_lwe_list_lhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_lhs = +/// PlaintextList::from_container(vec![1u64 << 60, 1 << 60, 2 << 60, 3 << 60]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_lhs, +/// &input_plaintext_list_lhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// // Create the second LweCiphertextList +/// let mut input_lwe_list_rhs = LweCiphertextList::new( +/// 0u64, +/// input_lwe_dimension.to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// let input_plaintext_list_rhs = +/// PlaintextList::from_container(vec![2u64 << 60, 3 << 60, 1 << 60, 0 << 60]); +/// +/// encrypt_lwe_ciphertext_list( +/// &input_lwe_secret_key, +/// &mut input_lwe_list_rhs, +/// &input_plaintext_list_rhs, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let glwe_relin_key = allocate_and_generate_glwe_relinearization_key( +/// &glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Define the output lwe secret key +/// let output_lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); +/// +/// // Create the output LweCiphertext +/// let mut output_lwe_ciphertext_list = LweCiphertextList::new( +/// 0_u64, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// LweCiphertextCount(number_of_lwes), +/// ciphertext_modulus, +/// ); +/// +/// packed_lwe_multiplication_via_trace_packing_keyswitch( +/// &input_lwe_list_lhs, +/// &input_lwe_list_rhs, +/// &lwe_tpksk, +/// &glwe_relin_key, +/// 1u64 << 60, +/// &mut output_lwe_ciphertext_list, +/// ); +/// +/// let mut output_plaintext_list = PlaintextList::new(0u64, PlaintextCount(number_of_lwes)); +/// decrypt_lwe_ciphertext_list( +/// &output_lwe_secret_key, +/// &output_lwe_ciphertext_list, +/// &mut output_plaintext_list, +/// ); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// output_plaintext_list +/// .iter_mut() +/// .for_each(|elt| *elt.0 = decomposer.closest_representable(*elt.0)); +/// +/// // Get the raw vector +/// let mut cleartext_list = output_plaintext_list.into_container(); +/// // Remove the encoding +/// cleartext_list.iter_mut().for_each(|elt| *elt >>= 60); +/// // Get the list immutably +/// let cleartext_list = cleartext_list; +/// +/// let expected_result = [2u64, 3, 2, 0]; +/// for (cleartext, expected) in cleartext_list.iter().zip(expected_result.iter()) { +/// assert_eq!(cleartext, expected); +/// } +/// ``` +pub fn packed_lwe_multiplication_via_trace_packing_keyswitch< + Scalar, + InputCont, + KeyCont, + OutputCont, +>( + input_lwe_ciphertext_list_1: &LweCiphertextList, + input_lwe_ciphertext_list_2: &LweCiphertextList, + lwe_tpksk: &LweTracePackingKeyswitchKey, + relinearization_key: &GlweRelinearizationKey, + scale: Scalar, + output_lwe_ciphertext_list: &mut LweCiphertextList, +) where + Scalar: UnsignedTorus + CastInto + CastFrom, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + let lwe_count = input_lwe_ciphertext_list_1.lwe_ciphertext_count(); + assert_eq!( + input_lwe_ciphertext_list_2.lwe_ciphertext_count(), + lwe_count + ); + assert_eq!(output_lwe_ciphertext_list.lwe_ciphertext_count(), lwe_count); + assert!( + lwe_count.0.pow(2) <= lwe_tpksk.polynomial_size().0, + "Too many input LWEs. The number of LWEs in each input lwe list must + be at most the square root of the polynomial size.", + ); + + let mut packed_glwe_1 = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + let mut indices_1 = vec![0_usize; lwe_count.0]; + indices_1 + .iter_mut() + .enumerate() + .for_each(|(index, val)| *val = index); + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext( + lwe_tpksk, + &mut packed_glwe_1, + input_lwe_ciphertext_list_1, + &indices_1, + ); + let mut packed_glwe_2 = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + let mut indices_2 = vec![0_usize; lwe_count.0]; + indices_2 + .iter_mut() + .enumerate() + .for_each(|(index, val)| *val = index * lwe_count.0); + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext( + lwe_tpksk, + &mut packed_glwe_2, + input_lwe_ciphertext_list_2, + &indices_2, + ); + let mut relin_glwe_ciphertext = GlweCiphertext::new( + Scalar::ZERO, + lwe_tpksk.output_glwe_size(), + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + tensor_mult_with_relin( + &packed_glwe_1, + &packed_glwe_2, + scale, + relinearization_key, + &mut relin_glwe_ciphertext, + ); + + output_lwe_ciphertext_list + .iter_mut() + .enumerate() + .for_each(|(iter, mut el)| { + extract_lwe_sample_from_glwe_ciphertext( + &relin_glwe_ciphertext, + &mut el, + MonomialDegree(iter * (lwe_count.0 + 1)), + ) + }); +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch.rs b/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch.rs new file mode 100644 index 0000000000..add24f963f --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch.rs @@ -0,0 +1,864 @@ +//! Module containing primitives pertaining to [`LWE trace pacling +//! keyswitch`](`LweTracePackingKeyswitchKey#lwe-trace-packing-keyswitch`). + +use crate::core_crypto::algorithms::glwe_keyswitch::*; +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// Apply a trace packing keyswitch on an input [`LWE ciphertext list`](`LweCiphertextList`) and +/// pack the result in an output [`GLWE ciphertext`](`GlweCiphertext`). +/// +/// ``` +/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweTracePackingKeyswitchKey creation +/// let lwe_dimension = LweDimension(742); +/// let lwe_count = LweCiphertextCount(2048); +/// let polynomial_size = PolynomialSize(2048); +/// let glwe_dim = (lwe_dimension.0 - 1) / polynomial_size.0 + 1; +/// let glwe_dimension = GlweDimension(glwe_dim); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000003925799891201197), 0.0); +/// let glwe_noise_distribution = Gaussian::from_dispersion_parameter( +/// StandardDev(0.00000000000000000000007069849454709433), +/// 0.0, +/// ); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta: u64 = 1 << 60; +/// +/// let mut seeder = new_seeder(); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); +/// +/// let mut glwe_secret_key = GlweSecretKey::new_empty_key(0u64, glwe_dimension, polynomial_size); +/// +/// generate_tpksk_output_glwe_secret_key( +/// &lwe_secret_key, +/// &mut glwe_secret_key, +/// ciphertext_modulus, +/// &mut secret_generator, +/// ); +/// +/// let decomp_base_log = DecompositionBaseLog(23); +/// let decomp_level_count = DecompositionLevelCount(1); +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// +/// let mut lwe_tpksk = LweTracePackingKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// lwe_dimension.to_lwe_size(), +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_lwe_trace_packing_keyswitch_key( +/// &glwe_secret_key, +/// &mut lwe_tpksk, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let mut lwe_ctxt_list = LweCiphertextList::new( +/// 0u64, +/// lwe_dimension.to_lwe_size(), +/// lwe_count, +/// ciphertext_modulus, +/// ); +/// +/// let msg = 7u64; +/// let plaintext_list = PlaintextList::new(msg * delta, PlaintextCount(lwe_count.0)); +/// +/// encrypt_lwe_ciphertext_list( +/// &lwe_secret_key, +/// &mut lwe_ctxt_list, +/// &plaintext_list, +/// lwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_glwe_ciphertext = GlweCiphertext::new( +/// 0u64, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// let mut indices = vec![0_usize; lwe_count.0]; +/// for (index, item) in indices.iter_mut().enumerate() { +/// *item = index; +/// } +/// +/// trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext( +/// &lwe_tpksk, +/// &mut output_glwe_ciphertext, +/// &lwe_ctxt_list, +/// &indices, +/// ); +/// +/// let mut output_plaintext_list = PlaintextList::new(0u64, PlaintextCount(polynomial_size.0)); +/// +/// decrypt_glwe_ciphertext( +/// &glwe_secret_key, +/// &output_glwe_ciphertext, +/// &mut output_plaintext_list, +/// ); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// // Get the raw vector +/// let mut cleartext_list = output_plaintext_list.into_container(); +/// // Remove the encoding +/// cleartext_list +/// .iter_mut() +/// .for_each(|elt| *elt = decomposer.decode_plaintext(*elt)); +/// // Get the list immutably +/// let cleartext_list = cleartext_list; +/// +/// // Check we recovered the original message for each plaintext we encrypted +/// for (index, elt) in cleartext_list.iter().enumerate() { +/// if index < lwe_count.0 { +/// assert_eq!(*elt, msg); +/// } else { +/// assert_eq!(*elt, 0); +/// } +/// } +/// ``` +pub fn trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext< + Scalar, + KeyCont, + InputCont, + OutputCont, +>( + lwe_tpksk: &LweTracePackingKeyswitchKey, + output_glwe_ciphertext: &mut GlweCiphertext, + input_lwe_ciphertext_list: &LweCiphertextList, + indices: &[usize], +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert_eq!( + input_lwe_ciphertext_list.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert_eq!( + output_glwe_ciphertext.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + + if lwe_tpksk + .ciphertext_modulus() + .is_compatible_with_native_modulus() + { + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_native_mod_compatible( + lwe_tpksk, + output_glwe_ciphertext, + input_lwe_ciphertext_list, + indices, + ) + } else { + let custom_modulus = lwe_tpksk.ciphertext_modulus().get_custom_modulus(); + if custom_modulus % 2 == 1 { + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_other_mod_odd( + lwe_tpksk, + output_glwe_ciphertext, + input_lwe_ciphertext_list, + indices, + ) + } else { + trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_other_mod_even( + lwe_tpksk, + output_glwe_ciphertext, + input_lwe_ciphertext_list, + indices, + ) + } + } +} + +pub fn trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_native_mod_compatible< + Scalar, + KeyCont, + InputCont, + OutputCont, +>( + lwe_tpksk: &LweTracePackingKeyswitchKey, + output_glwe_ciphertext: &mut GlweCiphertext, + input_lwe_ciphertext_list: &LweCiphertextList, + indices: &[usize], +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0 + <= output_glwe_ciphertext.polynomial_size().0 + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0, + indices.len() + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_size(), + lwe_tpksk.input_lwe_size() + ); + assert!(indices + .iter() + .all(|&x| x < output_glwe_ciphertext.polynomial_size().0)); + assert_eq!( + output_glwe_ciphertext.polynomial_size(), + lwe_tpksk.polynomial_size() + ); + assert_eq!( + output_glwe_ciphertext.glwe_size(), + lwe_tpksk.output_glwe_size() + ); + assert_eq!( + input_lwe_ciphertext_list.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert_eq!( + output_glwe_ciphertext.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert!(lwe_tpksk + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // We reset the output + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + let poly_size = output_glwe_ciphertext.polynomial_size(); + let glwe_size = output_glwe_ciphertext.glwe_size(); + let glwe_count = GlweCiphertextCount(poly_size.0); + let ciphertext_modulus = output_glwe_ciphertext.ciphertext_modulus(); + + let mut glwe_list = GlweCiphertextList::new( + Scalar::ZERO, + glwe_size, + poly_size, + glwe_count, + ciphertext_modulus, + ); + + // Construct the initial Glwe Ciphertexts + for (index1, mut glwe_ct) in glwe_list.iter_mut().enumerate() { + for (index2, index) in indices.iter().enumerate() { + if index1 == *index { + let lwe_ct = input_lwe_ciphertext_list.get(index2); + let lwe_body = lwe_ct.get_body(); //lwe_ct.as_ref().last().unwrap(); + let lwe_mask = lwe_ct.get_mask(); + for (index3, mut poly) in glwe_ct + .get_mut_mask() + .as_mut_polynomial_list() + .iter_mut() + .enumerate() + { + for (index4, coef) in poly.iter_mut().enumerate() { + if index3 * poly_size.0 + index4 < lwe_mask.lwe_dimension().0 { + *coef = + coef.wrapping_add(lwe_mask.as_ref()[index3 * poly_size.0 + index4]); + } + } + } + let mut glwe_body = glwe_ct.get_mut_body(); + let mut glwe_body_poly = glwe_body.as_mut_polynomial(); + glwe_body_poly[0] = *lwe_body.data; + } + } + } + + // This bit determines if we round an odd value down (if rounding_bit is zero) + // or round up (if rounding_bit is one) + // We flip this bit whenever it is used to get an rounding that is close to + // randomly rounding up or down with equal probability. + let mut rounding_bit = Scalar::ZERO; + + for l in 0..poly_size.log2().0 { + for i in 0..(poly_size.0 / 2_usize.pow(l as u32 + 1)) { + let ct_0 = glwe_list.get(i); + //let glwe_size = ct_0.glwe_size(); + let j = (poly_size.0 / 2_usize.pow(l as u32 + 1)) + i; + let ct_1 = glwe_list.get(j); + if ct_0.as_ref().iter().any(|&x| x != Scalar::ZERO) + || ct_1.as_ref().iter().any(|&x| x != Scalar::ZERO) + { + // Diving ct_0 and ct_1 by 2 + for mut pol in glwe_list.get_mut(i).as_mut_polynomial_list().iter_mut() { + pol.iter_mut().for_each(|coef| { + if *coef % Scalar::TWO == Scalar::ZERO { + *coef >>= 1 + } else { + // Round up or down depending on rounding bit + *coef = (*coef >> 1) + rounding_bit; + rounding_bit = Scalar::ONE - rounding_bit; + } + }) + } + for mut pol in glwe_list.get_mut(j).as_mut_polynomial_list().iter_mut() { + pol.iter_mut().for_each(|coef| { + if *coef % Scalar::TWO == Scalar::ZERO { + *coef >>= 1 + } else { + // Round up or down depending on rounding bit + *coef = (*coef >> 1) + rounding_bit; + rounding_bit = Scalar::ONE - rounding_bit; + } + }) + } + + // Rotate ct_1 by N/2^(l+1) + for mut pol in glwe_list.get_mut(j).as_mut_polynomial_list().iter_mut() { + polynomial_wrapping_monic_monomial_mul_assign( + &mut pol, + MonomialDegree(poly_size.0 / 2_usize.pow(l as u32 + 1)), + ); + } + + let mut ct_plus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + let mut ct_minus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + + for ((mut pol_plus, pol_0), pol_1) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_plus, &pol_0); + polynomial_wrapping_add_assign(&mut pol_plus, &pol_1); + } + + for ((mut pol_minus, pol_0), pol_1) in ct_minus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_minus, &pol_0); + polynomial_wrapping_sub_assign(&mut pol_minus, &pol_1); + } + + // Apply the automorphism sending X to X^(2^(l+1) + 1) to ct_minus + for mut pol in ct_minus.as_mut_polynomial_list().iter_mut() { + apply_automorphism_assign(&mut pol, 2_usize.pow(l as u32 + 1) + 1) + } + + let mut ks_out = GlweCiphertext::new( + Scalar::ZERO, + ct_minus.glwe_size(), + poly_size, + ciphertext_modulus, + ); + + let glwe_ksk = GlweKeyswitchKey::from_container( + lwe_tpksk.get(l).into_container(), + lwe_tpksk.decomposition_base_log(), + lwe_tpksk.decomposition_level_count(), + glwe_size, + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + + // Perform a Glwe keyswitch on ct_minus + keyswitch_glwe_ciphertext(&glwe_ksk, &ct_minus, &mut ks_out); + + // Set ct_0 to zero + glwe_list.get_mut(i).as_mut().fill(Scalar::ZERO); + + // Add the result to ct_plus and add this to ct_0 + for ((mut pol_plus, pol_ks), mut pol_0) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(ks_out.as_polynomial_list().iter()) + .zip(glwe_list.get_mut(i).as_mut_polynomial_list().iter_mut()) + { + polynomial_wrapping_add_assign(&mut pol_plus, &pol_ks); + polynomial_wrapping_add_assign(&mut pol_0, &pol_plus); + } + } + } + } + let res = glwe_list.get(0); + for (mut pol_out, pol_res) in output_glwe_ciphertext + .as_mut_polynomial_list() + .iter_mut() + .zip(res.as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_out, &pol_res); + } +} + +pub fn trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_other_mod_odd< + Scalar, + KeyCont, + InputCont, + OutputCont, +>( + lwe_tpksk: &LweTracePackingKeyswitchKey, + output_glwe_ciphertext: &mut GlweCiphertext, + input_lwe_ciphertext_list: &LweCiphertextList, + indices: &[usize], +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0 + <= output_glwe_ciphertext.polynomial_size().0 + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0, + indices.len() + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_size(), + lwe_tpksk.input_lwe_size() + ); + assert!(indices + .iter() + .all(|&x| x < output_glwe_ciphertext.polynomial_size().0)); + assert_eq!( + output_glwe_ciphertext.polynomial_size(), + lwe_tpksk.polynomial_size() + ); + assert_eq!( + output_glwe_ciphertext.glwe_size(), + lwe_tpksk.output_glwe_size() + ); + assert_eq!( + input_lwe_ciphertext_list.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert_eq!( + output_glwe_ciphertext.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert!(!lwe_tpksk + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // We reset the output + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + let poly_size = output_glwe_ciphertext.polynomial_size(); + let glwe_size = output_glwe_ciphertext.glwe_size(); + let glwe_count = GlweCiphertextCount(poly_size.0); + let ciphertext_modulus = output_glwe_ciphertext.ciphertext_modulus(); + let modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into(); + + let mut glwe_list = GlweCiphertextList::new( + Scalar::ZERO, + glwe_size, + poly_size, + glwe_count, + ciphertext_modulus, + ); + + // Construct the initial Glwe Ciphertexts + for (index1, mut glwe_ct) in glwe_list.iter_mut().enumerate() { + for (index2, index) in indices.iter().enumerate() { + if index1 == *index { + let lwe_ct = input_lwe_ciphertext_list.get(index2); + let lwe_body = lwe_ct.get_body(); + let lwe_mask = lwe_ct.get_mask(); + for (index3, mut poly) in glwe_ct + .get_mut_mask() + .as_mut_polynomial_list() + .iter_mut() + .enumerate() + { + for (index4, coef) in poly.iter_mut().enumerate() { + if index3 * poly_size.0 + index4 < lwe_mask.lwe_dimension().0 { + *coef = + coef.wrapping_add(lwe_mask.as_ref()[index3 * poly_size.0 + index4]); + } + } + } + let mut glwe_body = glwe_ct.get_mut_body(); + let mut glwe_body_poly = glwe_body.as_mut_polynomial(); + glwe_body_poly[0] = *lwe_body.data; + } + } + } + + for l in 0..poly_size.log2().0 { + for i in 0..(poly_size.0 / 2_usize.pow(l as u32 + 1)) { + let ct_0 = glwe_list.get(i); + //let glwe_size = ct_0.glwe_size(); + let j = (poly_size.0 / 2_usize.pow(l as u32 + 1)) + i; + let ct_1 = glwe_list.get(j); + if ct_0.as_ref().iter().any(|&x| x != Scalar::ZERO) + || ct_1.as_ref().iter().any(|&x| x != Scalar::ZERO) + { + // Rotate ct_1 by N/2^(l+1) + for mut pol in glwe_list.get_mut(j).as_mut_polynomial_list().iter_mut() { + polynomial_wrapping_monic_monomial_mul_assign_custom_mod( + &mut pol, + MonomialDegree(poly_size.0 / 2_usize.pow(l as u32 + 1)), + modulus_as_scalar, + ); + } + + let mut ct_plus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + let mut ct_minus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + + for ((mut pol_plus, pol_0), pol_1) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_plus, &pol_0); + polynomial_wrapping_add_assign_custom_mod( + &mut pol_plus, + &pol_1, + modulus_as_scalar, + ); + } + + for ((mut pol_minus, pol_0), pol_1) in ct_minus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_minus, &pol_0); + polynomial_wrapping_sub_assign_custom_mod( + &mut pol_minus, + &pol_1, + modulus_as_scalar, + ); + } + + // Scale the ciphertexts by 2^-1 = (q + 1)/2 when q is odd + let scalar = (modulus_as_scalar + Scalar::ONE) / Scalar::TWO; + for mut pol in ct_plus.as_mut_polynomial_list().iter_mut() { + polynomial_wrapping_scalar_mul_assign_custom_mod( + &mut pol, + scalar, + modulus_as_scalar, + ); + } + for mut pol in ct_minus.as_mut_polynomial_list().iter_mut() { + polynomial_wrapping_scalar_mul_assign_custom_mod( + &mut pol, + scalar, + modulus_as_scalar, + ); + } + + // Apply the automorphism sending X to X^(2^(l+1) + 1) to ct_minus + for mut pol in ct_minus.as_mut_polynomial_list().iter_mut() { + apply_automorphism_assign_custom_mod( + &mut pol, + 2_usize.pow(l as u32 + 1) + 1, + modulus_as_scalar, + ) + } + + let mut ks_out = GlweCiphertext::new( + Scalar::ZERO, + ct_minus.glwe_size(), + poly_size, + ciphertext_modulus, + ); + + let glwe_ksk = GlweKeyswitchKey::from_container( + lwe_tpksk.get(l).into_container(), + lwe_tpksk.decomposition_base_log(), + lwe_tpksk.decomposition_level_count(), + glwe_size, + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + + // Perform a Glwe keyswitch on ct_minus + keyswitch_glwe_ciphertext(&glwe_ksk, &ct_minus, &mut ks_out); + + // Set ct_0 to zero + glwe_list.get_mut(i).as_mut().fill(Scalar::ZERO); + + // Add the result to ct_plus and add this to ct_0 + for ((mut pol_plus, pol_ks), mut pol_0) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(ks_out.as_polynomial_list().iter()) + .zip(glwe_list.get_mut(i).as_mut_polynomial_list().iter_mut()) + { + polynomial_wrapping_add_assign_custom_mod( + &mut pol_plus, + &pol_ks, + modulus_as_scalar, + ); + polynomial_wrapping_add_assign(&mut pol_0, &pol_plus); + } + } + } + } + let res = glwe_list.get(0); + for (mut pol_out, pol_res) in output_glwe_ciphertext + .as_mut_polynomial_list() + .iter_mut() + .zip(res.as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_out, &pol_res); + } +} + +pub fn trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext_other_mod_even< + Scalar, + KeyCont, + InputCont, + OutputCont, +>( + lwe_tpksk: &LweTracePackingKeyswitchKey, + output_glwe_ciphertext: &mut GlweCiphertext, + input_lwe_ciphertext_list: &LweCiphertextList, + indices: &[usize], +) where + Scalar: UnsignedInteger, + KeyCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0 + <= output_glwe_ciphertext.polynomial_size().0 + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_ciphertext_count().0, + indices.len() + ); + assert_eq!( + input_lwe_ciphertext_list.lwe_size(), + lwe_tpksk.input_lwe_size() + ); + assert!(indices + .iter() + .all(|&x| x < output_glwe_ciphertext.polynomial_size().0)); + assert_eq!( + output_glwe_ciphertext.polynomial_size(), + lwe_tpksk.polynomial_size() + ); + assert_eq!( + output_glwe_ciphertext.glwe_size(), + lwe_tpksk.output_glwe_size() + ); + assert_eq!( + input_lwe_ciphertext_list.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert_eq!( + output_glwe_ciphertext.ciphertext_modulus(), + lwe_tpksk.ciphertext_modulus() + ); + assert!(!lwe_tpksk + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // We reset the output + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + let poly_size = output_glwe_ciphertext.polynomial_size(); + let glwe_size = output_glwe_ciphertext.glwe_size(); + let glwe_count = GlweCiphertextCount(poly_size.0); + let ciphertext_modulus = output_glwe_ciphertext.ciphertext_modulus(); + let modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into(); + + let mut glwe_list = GlweCiphertextList::new( + Scalar::ZERO, + glwe_size, + poly_size, + glwe_count, + ciphertext_modulus, + ); + + // Construct the initial Glwe Ciphertexts + for (index1, mut glwe_ct) in glwe_list.iter_mut().enumerate() { + for (index2, index) in indices.iter().enumerate() { + if index1 == *index { + let lwe_ct = input_lwe_ciphertext_list.get(index2); + let lwe_body = lwe_ct.get_body(); + let lwe_mask = lwe_ct.get_mask(); + for (index3, mut poly) in glwe_ct + .get_mut_mask() + .as_mut_polynomial_list() + .iter_mut() + .enumerate() + { + for (index4, coef) in poly.iter_mut().enumerate() { + if index3 * poly_size.0 + index4 < lwe_mask.lwe_dimension().0 { + *coef = + coef.wrapping_add(lwe_mask.as_ref()[index3 * poly_size.0 + index4]); + } + } + } + let mut glwe_body = glwe_ct.get_mut_body(); + let mut glwe_body_poly = glwe_body.as_mut_polynomial(); + glwe_body_poly[0] = *lwe_body.data; + } + } + } + + // This bit determines if we round an odd value down (if rounding_bit is zero) + // or round up (if rounding_bit is one) + // We flip this bit whenever it is used to get an rounding that is close to + // randomly rounding up or down with equal probability. + let mut rounding_bit = Scalar::ZERO; + + for l in 0..poly_size.log2().0 { + for i in 0..(poly_size.0 / 2_usize.pow(l as u32 + 1)) { + let ct_0 = glwe_list.get(i); + //let glwe_size = ct_0.glwe_size(); + let j = (poly_size.0 / 2_usize.pow(l as u32 + 1)) + i; + let ct_1 = glwe_list.get(j); + if ct_0.as_ref().iter().any(|&x| x != Scalar::ZERO) + || ct_1.as_ref().iter().any(|&x| x != Scalar::ZERO) + { + // Diving ct_0 and ct_1 by 2 + for mut pol in glwe_list.get_mut(i).as_mut_polynomial_list().iter_mut() { + pol.iter_mut().for_each(|coef| { + if *coef % Scalar::TWO == Scalar::ZERO { + *coef >>= 1 + } else { + // Round up or down depending on rounding bit + *coef = (*coef >> 1) + rounding_bit; + rounding_bit = Scalar::ONE - rounding_bit; + } + }) + } + for mut pol in glwe_list.get_mut(j).as_mut_polynomial_list().iter_mut() { + pol.iter_mut().for_each(|coef| { + if *coef % Scalar::TWO == Scalar::ZERO { + *coef >>= 1 + } else { + // Round up or down depending on rounding bit + *coef = (*coef >> 1) + rounding_bit; + rounding_bit = Scalar::ONE - rounding_bit; + } + }) + } + + // Rotate ct_1 by N/2^(l+1) + for mut pol in glwe_list.get_mut(j).as_mut_polynomial_list().iter_mut() { + polynomial_wrapping_monic_monomial_mul_assign_custom_mod( + &mut pol, + MonomialDegree(poly_size.0 / 2_usize.pow(l as u32 + 1)), + modulus_as_scalar, + ); + } + + let mut ct_plus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + let mut ct_minus = + GlweCiphertext::new(Scalar::ZERO, glwe_size, poly_size, ciphertext_modulus); + + for ((mut pol_plus, pol_0), pol_1) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_plus, &pol_0); + polynomial_wrapping_add_assign_custom_mod( + &mut pol_plus, + &pol_1, + modulus_as_scalar, + ); + } + + for ((mut pol_minus, pol_0), pol_1) in ct_minus + .as_mut_polynomial_list() + .iter_mut() + .zip(glwe_list.get(i).as_polynomial_list().iter()) + .zip(glwe_list.get(j).as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_minus, &pol_0); + polynomial_wrapping_sub_assign_custom_mod( + &mut pol_minus, + &pol_1, + modulus_as_scalar, + ); + } + + // Apply the automorphism sending X to X^(2^(l+1) + 1) to ct_minus + for mut pol in ct_minus.as_mut_polynomial_list().iter_mut() { + apply_automorphism_assign_custom_mod( + &mut pol, + 2_usize.pow(l as u32 + 1) + 1, + modulus_as_scalar, + ) + } + + let mut ks_out = GlweCiphertext::new( + Scalar::ZERO, + ct_minus.glwe_size(), + poly_size, + ciphertext_modulus, + ); + + let glwe_ksk = GlweKeyswitchKey::from_container( + lwe_tpksk.get(l).into_container(), + lwe_tpksk.decomposition_base_log(), + lwe_tpksk.decomposition_level_count(), + glwe_size, + lwe_tpksk.polynomial_size(), + lwe_tpksk.ciphertext_modulus(), + ); + + // Perform a Glwe keyswitch on ct_minus + keyswitch_glwe_ciphertext(&glwe_ksk, &ct_minus, &mut ks_out); + + // Set ct_0 to zero + glwe_list.get_mut(i).as_mut().fill(Scalar::ZERO); + + // Add the result to ct_plus and add this to ct_0 + for ((mut pol_plus, pol_ks), mut pol_0) in ct_plus + .as_mut_polynomial_list() + .iter_mut() + .zip(ks_out.as_polynomial_list().iter()) + .zip(glwe_list.get_mut(i).as_mut_polynomial_list().iter_mut()) + { + polynomial_wrapping_add_assign_custom_mod( + &mut pol_plus, + &pol_ks, + modulus_as_scalar, + ); + polynomial_wrapping_add_assign(&mut pol_0, &pol_plus); + } + } + } + } + let res = glwe_list.get(0); + for (mut pol_out, pol_res) in output_glwe_ciphertext + .as_mut_polynomial_list() + .iter_mut() + .zip(res.as_polynomial_list().iter()) + { + polynomial_wrapping_add_assign(&mut pol_out, &pol_res); + } +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch_key_generation.rs b/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch_key_generation.rs new file mode 100644 index 0000000000..a175570d41 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_trace_packing_keyswitch_key_generation.rs @@ -0,0 +1,417 @@ +//! Module containing primitives pertaining to [`LWE trace packing keyswitch key +//! generation`](`LweTracePackingKeyswitchKey`). + +use crate::core_crypto::algorithms::*; +use crate::core_crypto::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator}; +use crate::core_crypto::commons::math::random::{ + Distribution, RandomGenerable, Uniform, UniformBinary, +}; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::prelude::polynomial_algorithms::apply_automorphism_wrapping_add_assign; +use crate::core_crypto::prelude::CiphertextModulus; + +/// Fill a [`GLWE secret key`](`GlweSecretKey`) with an actual key derived from an +/// [`LWE secret key`](`LweSecretKey`) for use in the [`LWE trace packing keyswitch key`] +/// (`LweTracePackingKeyswitchKey`) +pub fn generate_tpksk_output_glwe_secret_key( + input_lwe_secret_key: &LweSecretKey, + output_glwe_secret_key: &mut GlweSecretKey, + ciphertext_modulus: CiphertextModulus, + generator: &mut SecretRandomGenerator, +) where + Scalar: RandomGenerable + UnsignedInteger, + InputKeyCont: Container, + OutputKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + let lwe_dimension = input_lwe_secret_key.lwe_dimension(); + let glwe_dimension = output_glwe_secret_key.glwe_dimension(); + let glwe_poly_size = output_glwe_secret_key.polynomial_size(); + + assert!( + lwe_dimension.0 <= glwe_dimension.0 * glwe_poly_size.0, + "Mismatched between input_lwe_secret_key dimension {:?} and number of coefficients of \ + output_glwe_secret_key {:?}.", + lwe_dimension.0, + glwe_dimension.0 * glwe_poly_size.0 + ); + + let glwe_key_container = output_glwe_secret_key.as_mut(); + + if lwe_dimension.0 < glwe_dimension.0 * glwe_poly_size.0 { + let additional_key_bits = LweSecretKey::generate_new_binary( + LweDimension(glwe_dimension.0 * glwe_poly_size.0 - lwe_dimension.0), + generator, + ); + let extended_lwe_key_iter = input_lwe_secret_key + .as_ref() + .iter() + .chain(additional_key_bits.as_ref().iter()); + for (index, lwe_key_bit) in extended_lwe_key_iter.enumerate() { + if index % glwe_poly_size.0 == 0 { + glwe_key_container[index] = *lwe_key_bit; + } else { + let rem = index % glwe_poly_size.0; + let quo = index / glwe_poly_size.0; + let new_index = (quo + 1) * glwe_poly_size.0 - rem; + if ciphertext_modulus.is_compatible_with_native_modulus() { + glwe_key_container[new_index] = lwe_key_bit.wrapping_neg(); + } else { + glwe_key_container[new_index] = lwe_key_bit.wrapping_neg_custom_mod( + ciphertext_modulus.get_custom_modulus().cast_into(), + ); + } + } + } + } else { + let extended_lwe_key_iter = input_lwe_secret_key.as_ref().iter(); + for (index, lwe_key_bit) in extended_lwe_key_iter.enumerate() { + if index % glwe_poly_size.0 == 0 { + glwe_key_container[index] = *lwe_key_bit; + } else { + let rem = index % glwe_poly_size.0; + let quo = index / glwe_poly_size.0; + let new_index = (quo + 1) * glwe_poly_size.0 - rem; + if ciphertext_modulus.is_compatible_with_native_modulus() { + glwe_key_container[new_index] = lwe_key_bit.wrapping_neg(); + } else { + glwe_key_container[new_index] = lwe_key_bit.wrapping_neg_custom_mod( + ciphertext_modulus.get_custom_modulus().cast_into(), + ); + } + } + } + } +} + +/// Fill an [`LWE trace packing keyswitch key`](`LweTracePackingKeyswitchKey`) +/// with an actual key. +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweCiphertext creation +/// let glwe_size = GlweSize(2); +/// let polynomial_size = PolynomialSize(1024); +/// let lwe_dimension = LweDimension(900); +/// let noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0); +/// let ciphertext_modulus = CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(); +/// +/// let mut seeder = new_seeder(); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); +/// +/// let mut glwe_secret_key = +/// GlweSecretKey::new_empty_key(0u64, glwe_size.to_glwe_dimension(), polynomial_size); +/// +/// generate_tpksk_output_glwe_secret_key( +/// &lwe_secret_key, +/// &mut glwe_secret_key, +/// ciphertext_modulus, +/// &mut secret_generator, +/// ); +/// +/// let decomp_base_log = DecompositionBaseLog(2); +/// let decomp_level_count = DecompositionLevelCount(8); +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// +/// let mut lwe_tpksk = LweTracePackingKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// lwe_dimension.to_lwe_size(), +/// glwe_size, +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_lwe_trace_packing_keyswitch_key( +/// &glwe_secret_key, +/// &mut lwe_tpksk, +/// noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// assert!(!lwe_tpksk.as_ref().iter().all(|&x| x == 0)); +/// ``` +pub fn generate_lwe_trace_packing_keyswitch_key< + Scalar, + NoiseDistribution, + InputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_secret_key: &GlweSecretKey, + lwe_tpksk: &mut LweTracePackingKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + let ciphertext_modulus = lwe_tpksk.ciphertext_modulus(); + if ciphertext_modulus.is_compatible_with_native_modulus() { + generate_lwe_trace_packing_keyswitch_key_native_mod_compatible( + input_glwe_secret_key, + lwe_tpksk, + noise_distribution, + generator, + ) + } else { + generate_lwe_trace_packing_keyswitch_key_other_mod( + input_glwe_secret_key, + lwe_tpksk, + noise_distribution, + generator, + ) + } +} + +pub fn generate_lwe_trace_packing_keyswitch_key_native_mod_compatible< + Scalar, + NoiseDistribution, + InputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_secret_key: &GlweSecretKey, + lwe_tpksk: &mut LweTracePackingKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert_eq!( + input_glwe_secret_key.glwe_dimension(), + lwe_tpksk.output_glwe_key_dimension() + ); + assert_eq!( + input_glwe_secret_key.polynomial_size(), + lwe_tpksk.polynomial_size() + ); + + let ciphertext_modulus = lwe_tpksk.ciphertext_modulus(); + assert!(ciphertext_modulus.is_compatible_with_native_modulus()); + + // We retrieve decomposition arguments + let glwe_dimension = lwe_tpksk.output_glwe_key_dimension(); + let decomp_level_count = lwe_tpksk.decomposition_level_count(); + let decomp_base_log = lwe_tpksk.decomposition_base_log(); + let polynomial_size = lwe_tpksk.polynomial_size(); + + let automorphism_index_iter = 1..=polynomial_size.log2().0; + + let gen_iter = generator + .try_fork_from_config(lwe_tpksk.encryption_fork_config(Uniform, noise_distribution)) + .unwrap(); + + // loop over the before key blocks + for ((auto_index, glwe_keyswitch_block), mut loop_generator) in automorphism_index_iter + .zip(lwe_tpksk.iter_mut()) + .zip(gen_iter) + { + let mut auto_glwe_sk_poly_list = PolynomialList::new( + Scalar::ZERO, + input_glwe_secret_key.polynomial_size(), + PolynomialCount(input_glwe_secret_key.glwe_dimension().0), + ); + let input_key_poly_list = input_glwe_secret_key.as_polynomial_list(); + let input_key_poly_iter = input_key_poly_list.iter(); + let auto_key_poly_iter = auto_glwe_sk_poly_list.iter_mut(); + for (mut auto_key_poly, input_key_poly) in auto_key_poly_iter.zip(input_key_poly_iter) { + apply_automorphism_wrapping_add_assign( + &mut auto_key_poly, + &input_key_poly, + 2_usize.pow(auto_index as u32) + 1, + ); + } + let mut glwe_ksk = GlweKeyswitchKey::from_container( + glwe_keyswitch_block.into_container(), + decomp_base_log, + decomp_level_count, + glwe_dimension.to_glwe_size(), + polynomial_size, + ciphertext_modulus, + ); + let auto_glwe_sk = + GlweSecretKey::from_container(auto_glwe_sk_poly_list.into_container(), polynomial_size); + generate_glwe_keyswitch_key( + &auto_glwe_sk, + input_glwe_secret_key, + &mut glwe_ksk, + noise_distribution, + &mut loop_generator, + ); + } +} + +pub fn generate_lwe_trace_packing_keyswitch_key_other_mod< + Scalar, + NoiseDistribution, + InputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_secret_key: &GlweSecretKey, + lwe_tpksk: &mut LweTracePackingKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert_eq!( + input_glwe_secret_key.glwe_dimension(), + lwe_tpksk.output_glwe_key_dimension() + ); + assert_eq!( + input_glwe_secret_key.polynomial_size(), + lwe_tpksk.polynomial_size() + ); + + let ciphertext_modulus = lwe_tpksk.ciphertext_modulus(); + assert!(!ciphertext_modulus.is_compatible_with_native_modulus()); + + // Convert the input glwe_ secret key to a polynomial list + // modulo the native modulus while keeping the sign + let mut native_glwe_secret_key_poly_list = PolynomialList::new( + Scalar::ZERO, + input_glwe_secret_key.polynomial_size(), + PolynomialCount(input_glwe_secret_key.glwe_dimension().0), + ); + // Need to go from custom to native modulus while preserving the sign + let modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into(); + input_glwe_secret_key + .as_ref() + .iter() + .zip(native_glwe_secret_key_poly_list.as_mut().iter_mut()) + .for_each(|(&src, dst)| { + if src > modulus_as_scalar / Scalar::TWO { + *dst = src.wrapping_sub(modulus_as_scalar) + } else { + *dst = src + } + }); + + // We retrieve decomposition arguments + let glwe_dimension = lwe_tpksk.output_glwe_key_dimension(); + let decomp_level_count = lwe_tpksk.decomposition_level_count(); + let decomp_base_log = lwe_tpksk.decomposition_base_log(); + let polynomial_size = lwe_tpksk.polynomial_size(); + + let automorphism_index_iter = 1..=polynomial_size.log2().0; + + let gen_iter = generator + .try_fork_from_config(lwe_tpksk.encryption_fork_config(Uniform, noise_distribution)) + .unwrap(); + + // loop over the before key blocks + for ((auto_index, glwe_keyswitch_block), mut loop_generator) in automorphism_index_iter + .zip(lwe_tpksk.iter_mut()) + .zip(gen_iter) + { + let mut auto_glwe_sk_poly_list = PolynomialList::new( + Scalar::ZERO, + input_glwe_secret_key.polynomial_size(), + PolynomialCount(input_glwe_secret_key.glwe_dimension().0), + ); + let native_key_poly_iter = native_glwe_secret_key_poly_list.iter(); + let auto_key_poly_iter = auto_glwe_sk_poly_list.iter_mut(); + for (mut auto_key_poly, native_key_poly) in auto_key_poly_iter.zip(native_key_poly_iter) { + apply_automorphism_wrapping_add_assign( + &mut auto_key_poly, + &native_key_poly, + 2_usize.pow(auto_index as u32) + 1, + ); + } + + let mut glwe_ksk = GlweKeyswitchKey::from_container( + glwe_keyswitch_block.into_container(), + decomp_base_log, + decomp_level_count, + glwe_dimension.to_glwe_size(), + polynomial_size, + ciphertext_modulus, + ); + let auto_glwe_sk = + GlweSecretKey::from_container(auto_glwe_sk_poly_list.into_container(), polynomial_size); + generate_glwe_keyswitch_key( + &auto_glwe_sk, + input_glwe_secret_key, + &mut glwe_ksk, + noise_distribution, + &mut loop_generator, + ); + } +} + +/// Allocate a new [`LWE trace packing keyswitch key`](`LweTracePackingKeyswitchKey`) and fill it +/// with an actual trace packing keyswitching key constructed from an associated input [`GLWE secret +/// key`](`GlweSecretKey`). +/// +/// See [`generate_tpksk_output_glwe_secret_key`](`generate_tpksk_output_glwe_secret_key`) +/// for more details. +/// +/// See [`trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext`](`super::trace_packing_keyswitch_lwe_ciphertext_list_into_glwe_ciphertext`) +/// for usage. +pub fn allocate_and_generate_new_lwe_trace_packing_keyswitch_key< + Scalar, + NoiseDistribution, + KeyCont, + Gen, +>( + lwe_size: LweSize, + glwe_secret_key: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, + generator: &mut EncryptionRandomGenerator, +) -> LweTracePackingKeyswitchKeyOwned +where + Scalar: Encryptable, + NoiseDistribution: Distribution, + KeyCont: Container, + Gen: ByteRandomGenerator, +{ + let mut new_lwe_trace_packing_keyswitch_key = LweTracePackingKeyswitchKeyOwned::new( + Scalar::ZERO, + decomp_base_log, + decomp_level_count, + lwe_size, + glwe_secret_key.glwe_dimension().to_glwe_size(), + glwe_secret_key.polynomial_size(), + ciphertext_modulus, + ); + + generate_lwe_trace_packing_keyswitch_key( + glwe_secret_key, + &mut new_lwe_trace_packing_keyswitch_key, + noise_distribution, + generator, + ); + + new_lwe_trace_packing_keyswitch_key +} diff --git a/tfhe/src/core_crypto/algorithms/mod.rs b/tfhe/src/core_crypto/algorithms/mod.rs index a94edac4ce..6e4996202b 100644 --- a/tfhe/src/core_crypto/algorithms/mod.rs +++ b/tfhe/src/core_crypto/algorithms/mod.rs @@ -5,9 +5,13 @@ pub mod ggsw_conversion; pub mod ggsw_encryption; pub mod glwe_encryption; +pub mod glwe_keyswitch; +pub mod glwe_keyswitch_key_generation; pub mod glwe_linear_algebra; +pub mod glwe_relinearization_key_generation; pub mod glwe_sample_extraction; pub mod glwe_secret_key_generation; +pub mod glwe_tensor_product; pub mod lwe_bootstrap_key_conversion; pub mod lwe_bootstrap_key_generation; pub mod lwe_compact_ciphertext_list_expansion; @@ -26,6 +30,8 @@ pub mod lwe_private_functional_packing_keyswitch_key_generation; pub mod lwe_programmable_bootstrapping; pub mod lwe_public_key_generation; pub mod lwe_secret_key_generation; +pub mod lwe_trace_packing_keyswitch; +pub mod lwe_trace_packing_keyswitch_key_generation; pub mod lwe_wopbs; #[cfg(feature = "zk-pok")] pub mod lwe_zero_knowledge_verification; @@ -53,9 +59,13 @@ pub(crate) mod test; pub use ggsw_conversion::*; pub use ggsw_encryption::*; pub use glwe_encryption::*; +pub use glwe_keyswitch::*; +pub use glwe_keyswitch_key_generation::*; pub use glwe_linear_algebra::*; +pub use glwe_relinearization_key_generation::*; pub use glwe_sample_extraction::*; pub use glwe_secret_key_generation::*; +pub use glwe_tensor_product::*; pub use lwe_bootstrap_key_conversion::*; pub use lwe_bootstrap_key_generation::*; pub use lwe_compact_ciphertext_list_expansion::*; @@ -74,6 +84,8 @@ pub use lwe_private_functional_packing_keyswitch_key_generation::*; pub use lwe_programmable_bootstrapping::*; pub use lwe_public_key_generation::*; pub use lwe_secret_key_generation::*; +pub use lwe_trace_packing_keyswitch::*; +pub use lwe_trace_packing_keyswitch_key_generation::*; pub use lwe_wopbs::*; #[cfg(feature = "zk-pok")] pub use lwe_zero_knowledge_verification::*; diff --git a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs index 43e4c39304..dad07497f5 100644 --- a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs +++ b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs @@ -334,6 +334,30 @@ pub fn polynomial_wrapping_add_mul_assign_custom_mod( + output: &mut Polynomial, + scalar: Scalar, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + PolyCont: ContainerMut, +{ + slice_wrapping_scalar_mul_assign_custom_mod(output.as_mut(), scalar, custom_modulus) +} + /// Divides (mod $(X^{N}+1)$), the output polynomial with a monic monomial of a given degree i.e. /// $X^{degree}$. /// @@ -919,6 +943,224 @@ pub fn polynomial_wrapping_sub_mul_assign_custom_mod( + output_poly_list: &mut PolynomialList, + input_poly_list: &PolynomialList, + scalar_poly: &Polynomial, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont: Container, + PolyCont: Container, +{ + assert_eq!( + output_poly_list.polynomial_size(), + input_poly_list.polynomial_size() + ); + assert_eq!( + output_poly_list.polynomial_count(), + input_poly_list.polynomial_count() + ); + for (mut output_poly, input_poly) in output_poly_list.iter_mut().zip(input_poly_list.iter()) { + polynomial_wrapping_sub_mul_assign(&mut output_poly, &input_poly, scalar_poly) + } +} + +pub fn polynomial_list_wrapping_sub_scalar_mul_assign_custom_mod< + Scalar, + InputCont, + OutputCont, + PolyCont, +>( + output_poly_list: &mut PolynomialList, + input_poly_list: &PolynomialList, + scalar_poly: &Polynomial, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont: Container, + PolyCont: Container, +{ + assert_eq!( + output_poly_list.polynomial_size(), + input_poly_list.polynomial_size() + ); + assert_eq!( + output_poly_list.polynomial_count(), + input_poly_list.polynomial_count() + ); + for (mut output_poly, input_poly) in output_poly_list.iter_mut().zip(input_poly_list.iter()) { + polynomial_wrapping_sub_mul_assign_custom_mod( + &mut output_poly, + &input_poly, + scalar_poly, + custom_modulus, + ) + } +} + +/// Apply an automorphism to the input [`Polynomial`](`Polynomial`) and add +/// the result to the output [`Polynomial`](`Polynomial`). +/// +/// The automorphism is specified by the exponent to which the polynomial +/// indeterminate is raised, namely the value e where X is mapped to X^e. +/// The automorphism exponent needs to be odd as we assume we are working +/// in a power of two cyclotomic ring. +pub fn apply_automorphism_wrapping_add_assign( + output: &mut Polynomial, + input: &Polynomial, + automorphism_exponent: usize, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + PolyCont: Container, +{ + // check input and output polynomials have the same size + assert_eq!(input.polynomial_size(), output.polynomial_size()); + + // check dimensions are a power of 2 + assert!(input.polynomial_size().0.is_power_of_two()); + + // check the automorphism exponent is odd so the function X -> X^automorphism_exponent is an + // automorphism (assumes polysize is a power of 2 which we just checked) + assert_eq!(automorphism_exponent % 2, 1); + + let poly_size = input.polynomial_size().0; + + for (index, coef) in input.iter().enumerate() { + let new_index = (index * automorphism_exponent) % poly_size; + if (index * automorphism_exponent) % (2 * poly_size) == new_index { + output[new_index] = output[new_index].wrapping_add(*coef); + } else { + output[new_index] = output[new_index].wrapping_sub(*coef); + } + } +} + +/// Apply an automorphism to the input [`Polynomial`](`Polynomial`) and add +/// the result to the output [`Polynomial`](`Polynomial`) modulo a custom modulus. +/// +/// The automorphism is specified by the exponent to which the polynomial +/// indeterminate is raised, namely the value e where X is mapped to X^e. +/// The automorphism exponent needs to be odd as we assume we are working +/// in a power of two cyclotomic ring. +pub fn apply_automorphism_wrapping_add_assign_custom_mod( + output: &mut Polynomial, + input: &Polynomial, + automorphism_exponent: usize, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + PolyCont: Container, +{ + // check input and output polynomials have the same size + assert_eq!(input.polynomial_size(), output.polynomial_size()); + + // check dimensions are a power of 2 + assert!(input.polynomial_size().0.is_power_of_two()); + + // check the automorphism exponent is odd so the function X -> X^automorphism_exponent is an + // automorphism (assumes polysize is a power of 2 which we just checked) + assert_eq!(automorphism_exponent % 2, 1); + + let poly_size = input.polynomial_size().0; + + for (index, coef) in input.iter().enumerate() { + let new_index = (index * automorphism_exponent) % poly_size; + if (index * automorphism_exponent) % (2 * poly_size) == new_index { + output[new_index] = output[new_index].wrapping_add_custom_mod(*coef, custom_modulus); + } else { + output[new_index] = output[new_index].wrapping_sub_custom_mod(*coef, custom_modulus); + } + } +} + +/// Apply an automorphism to the input [`Polynomial`](`Polynomial`). +/// +/// The automorphism is specified by the exponent to which the polynomial +/// indeterminate is raised, namely the value e where X is mapped to X^e. +/// The automorphism exponent needs to be odd as we assume we are working +/// in a power of two cyclotomic ring. +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*; +/// use tfhe::core_crypto::commons::parameters::*; +/// use tfhe::core_crypto::entities::*; +/// let mut poly = Polynomial::new(1_u8, PolynomialSize(32)); +/// apply_automorphism_assign(&mut poly, 5); +/// let expected = [ +/// 1u8, 1, 1, 255, 255, 1, 1, 1, 255, 255, 1, 1, 1, 255, 255, 1, 1, 1, 255, 255, 1, 1, 1, 255, +/// 255, 1, 1, 1, 255, 255, 1, 1, +/// ]; +/// +/// poly.as_ref() +/// .iter() +/// .zip(expected.iter()) +/// .for_each(|(&x, &y)| assert_eq!(x, y)); +/// ``` +pub fn apply_automorphism_assign( + input: &mut Polynomial, + automorphism_exponent: usize, +) where + Scalar: UnsignedInteger, + PolyCont: ContainerMut, +{ + let mut temp = Polynomial::new(Scalar::ZERO, input.polynomial_size()); + apply_automorphism_wrapping_add_assign(&mut temp, input, automorphism_exponent); + input.as_mut().fill(Scalar::ZERO); + polynomial_wrapping_add_assign(input, &temp); +} + +/// Apply an automorphism to the input [`Polynomial`](`Polynomial`) +/// modulo a custom modulus. +/// +/// The automorphism is specified by the exponent to which the polynomial +/// indeterminate is raised, namely the value e where X is mapped to X^e. +/// The automorphism exponent needs to be odd as we assume we are working +/// in a power of two cyclotomic ring. +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*; +/// use tfhe::core_crypto::commons::parameters::*; +/// use tfhe::core_crypto::entities::*; +/// let mut poly = Polynomial::new(1_u8, PolynomialSize(32)); +/// let custom_modulus = 223u8; +/// apply_automorphism_assign_custom_mod(&mut poly, 5, custom_modulus); +/// let expected = [ +/// 1u8, 1, 1, 222, 222, 1, 1, 1, 222, 222, 1, 1, 1, 222, 222, 1, 1, 1, 222, 222, 1, 1, 1, 222, +/// 222, 1, 1, 1, 222, 222, 1, 1, +/// ]; +/// +/// poly.as_ref() +/// .iter() +/// .zip(expected.iter()) +/// .for_each(|(&x, &y)| assert_eq!(x, y)); +/// ``` +pub fn apply_automorphism_assign_custom_mod( + input: &mut Polynomial, + automorphism_exponent: usize, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + PolyCont: ContainerMut, +{ + let mut temp = Polynomial::new(Scalar::ZERO, input.polynomial_size()); + apply_automorphism_wrapping_add_assign_custom_mod( + &mut temp, + input, + automorphism_exponent, + custom_modulus, + ); + input.as_mut().fill(Scalar::ZERO); + polynomial_wrapping_add_assign_custom_mod(input, &temp, custom_modulus); +} + /// Fill the output polynomial, with the result of the product of two polynomials, reduced modulo /// $(X^{N} + 1)$ with the schoolbook algorithm Complexity: $O(N^{2})$ /// diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 82ff9bf277..2becc3b330 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -1,6 +1,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus; use crate::core_crypto::commons::math::decomposition::{ - SignedDecompositionIter, SignedDecompositionNonNativeIter, ValueSign, + SignedDecompositionIter, SignedDecompositionNonNativeIter, SliceSignedDecompositionIter, + SliceSignedDecompositionNonNativeIter, ValueSign, }; use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger}; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; @@ -125,6 +126,36 @@ where native_closest_representable(input, self.level_count, self.base_log) } + /// Decode a plaintext value using the decoder to compute the closest representable. + pub fn decode_plaintext(&self, input: Scalar) -> Scalar { + let shift = Scalar::BITS - self.level_count * self.base_log; + self.closest_representable(input) >> shift + } + + /// Fills a mutable tensor-like objects with the closest representable values from another + /// tensor-like object. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// + /// let input = vec![1_340_987_234_u32; 2]; + /// let mut closest = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut closest, &input); + /// assert!(closest.iter().all(|&x| x == 1_341_128_704_u32)); + /// ``` + pub fn fill_slice_with_closest_representable(&self, output: &mut [Scalar], input: &[Scalar]) { + assert_eq!(output.len(), input.len()); + output + .iter_mut() + .zip(input.iter()) + .for_each(|(dst, &src)| *dst = self.closest_representable(src)); + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -193,6 +224,89 @@ where None } } + + /// Generates an iterator-like object over tensors of terms of the decomposition of the input + /// tensor. + /// + /// # Warning + /// + /// The returned iterator yields the terms $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ in + /// order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32, 1_340_987_234_u32]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// + /// let mut count = 0; + /// while let Some(term) = decomp.next_term() { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// for elmt in term.as_slice().iter() { + /// let signed_term = elmt.into_signed(); + /// let half_basis = 2i32.pow(4) / 2i32; + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term < half_basis); + /// } + /// count += 1; + /// } + /// assert_eq!(count, 3); + /// ``` + pub fn decompose_slice(&self, input: &[Scalar]) -> SliceSignedDecompositionIter { + // Note that there would be no sense of making the decomposition on an input which was + // not rounded to the closest representable first. We then perform it before decomposing. + let mut closest = vec![Scalar::ZERO; input.len()]; + self.fill_slice_with_closest_representable(&mut closest, input); + SliceSignedDecompositionIter::new( + &closest, + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + ) + } + + /// Fills the output tensor with the recomposition of another tensor. + /// + /// Returns `Some(())` if the decomposition was fresh, and the output was filled with a + /// recomposition, and `None`, if not. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut rounded = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut rounded, &decomposable); + /// let decomp = decomposer.decompose_slice(&rounded); + /// let mut recomposition = vec![0u32; 2]; + /// decomposer + /// .fill_slice_with_recompose(decomp, &mut recomposition) + /// .unwrap(); + /// assert_eq!(recomposition, rounded); + /// ``` + pub fn fill_slice_with_recompose( + &self, + decomp: SliceSignedDecompositionIter, + output: &mut [Scalar], + ) -> Option<()> { + let mut decomp = decomp; + if decomp.is_fresh() { + while let Some(term) = decomp.next_term() { + term.update_slice_with_recomposition_summand_wrapping_addition(output); + } + Some(()) + } else { + None + } + } } /// A structure which allows to decompose unsigned integers into a set of smaller terms for moduli @@ -388,6 +502,26 @@ where } } + /// Decode a plaintext value using the decoder modulo a custom modulus. + pub fn decode_plaintext(&self, input: Scalar) -> Scalar { + let ciphertext_modulus_as_scalar: Scalar = + self.ciphertext_modulus.get_custom_modulus().cast_into(); + let mut negate_input = false; + let mut ptxt = input; + if input > ciphertext_modulus_as_scalar >> 1 { + negate_input = true; + ptxt = ptxt.wrapping_neg_custom_mod(ciphertext_modulus_as_scalar); + } + let number_of_message_bits = self.base_log().0 * self.level_count().0; + let delta = ciphertext_modulus_as_scalar >> number_of_message_bits; + let half_delta = delta >> 1; + let mut decoded = (ptxt + half_delta) / delta; + if negate_input { + decoded = decoded.wrapping_neg_custom_mod(ciphertext_modulus_as_scalar); + } + decoded + } + #[inline(always)] pub fn init_decomposer_state(&self, input: Scalar) -> (Scalar, ValueSign) { let ciphertext_modulus_as_scalar: Scalar = @@ -419,6 +553,36 @@ where (abs_closest_representable, input_sign) } + pub fn init_decomposer_state_slice( + &self, + input: &[Scalar], + output: &mut [Scalar], + signs: &mut [ValueSign], + ) { + assert_eq!(input.len(), output.len()); + assert_eq!(input.len(), signs.len()); + let ciphertext_modulus_as_scalar: Scalar = + self.ciphertext_modulus.get_custom_modulus().cast_into(); + let shift_to_native = Scalar::BITS - self.ciphertext_modulus_bit_count() as usize; + + input + .iter() + .zip(output.iter_mut()) + .zip(signs.iter_mut()) + .for_each(|((input, output), sign)| { + if *input < ciphertext_modulus_as_scalar.div_ceil(Scalar::TWO) { + (*output, *sign) = (*input, ValueSign::Positive) + } else { + (*output, *sign) = (ciphertext_modulus_as_scalar - *input, ValueSign::Negative) + }; + *output = native_closest_representable( + *output << shift_to_native, + self.level_count, + self.base_log, + ) >> shift_to_native + }); + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -516,4 +680,151 @@ where None } } + + /// Fills a mutable tensor-like objects with the closest representable values from another + /// tensor-like object. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{SignedDecomposerNonNative, ValueSign}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 48) + 1).unwrap(), + /// ); + /// + /// let input = vec![249280154129830u64; 2]; + /// let mut closest = vec![0u64; 2]; + /// let mut signs = vec![ValueSign::Positive; 2]; + /// decomposer.init_decomposer_state_slice(&input, &mut closest, &mut signs); + /// assert!(closest.iter().all(|&x| x == 32160715112448u64)); + /// decomposer.fill_slice_with_closest_representable(&mut closest, &input); + /// assert!(closest.iter().all(|&x| x == 249314261598209u64)); + /// ``` + pub fn fill_slice_with_closest_representable(&self, output: &mut [Scalar], input: &[Scalar]) { + assert_eq!(output.len(), input.len()); + let mut signs = vec![ValueSign::Positive; input.len()]; + self.init_decomposer_state_slice(input, output, &mut signs); + + let modulus_as_scalar: Scalar = self.ciphertext_modulus.get_custom_modulus().cast_into(); + output + .iter_mut() + .zip(signs.iter()) + .for_each(|(output, sign)| match sign { + ValueSign::Positive => (), + ValueSign::Negative => *output = output.wrapping_neg_custom_mod(modulus_as_scalar), + }); + } + + /// Generates an iterator-like object over tensors of terms of the decomposition of the input + /// tensor. + /// + /// # Warning + /// + /// The returned iterator yields the terms $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ in + /// order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// + /// let decomposition_base_log = DecompositionBaseLog(4); + /// let decomposition_level_count = DecompositionLevelCount(3); + /// let ciphertext_modulus = CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(); + /// + /// let decomposer = SignedDecomposerNonNative::new( + /// decomposition_base_log, + /// decomposition_level_count, + /// ciphertext_modulus, + /// ); + /// + /// let basis = 2i64.pow(decomposition_base_log.0.try_into().unwrap()); + /// let half_basis = basis / 2; + /// + /// let decomposable = [9223372032559808513u64, 1u64 << 63]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// + /// let mut count = 0; + /// while let Some(term) = decomp.next_term() { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// for elmt in term.as_slice().iter() { + /// let signed_term = elmt.into_signed(); + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term <= half_basis); + /// } + /// count += 1; + /// } + /// assert_eq!(count, 3); + /// ``` + pub fn decompose_slice( + &self, + input: &[Scalar], + ) -> SliceSignedDecompositionNonNativeIter { + let mut abs_closest_representables = vec![Scalar::ZERO; input.len()]; + let mut signs = vec![ValueSign::Positive; input.len()]; + self.init_decomposer_state_slice(input, &mut abs_closest_representables, &mut signs); + + SliceSignedDecompositionNonNativeIter::new( + &abs_closest_representables, + &signs, + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + self.ciphertext_modulus, + ) + } + + /// Fills the output tensor with the recomposition of another tensor. + /// + /// Returns `Some(())` if the decomposition was fresh, and the output was filled with a + /// recomposition, and `None`, if not. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// + /// let ciphertext_modulus = CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(); + /// let decomposer = SignedDecomposerNonNative::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// ciphertext_modulus, + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut rounded = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut rounded, &decomposable); + /// let decomp = decomposer.decompose_slice(&rounded); + /// let mut recomposition = vec![0u32; 2]; + /// decomposer + /// .fill_slice_with_recompose(decomp, &mut recomposition) + /// .unwrap(); + /// assert_eq!(recomposition, rounded); + /// ``` + pub fn fill_slice_with_recompose( + &self, + decomp: SliceSignedDecompositionNonNativeIter, + output: &mut [Scalar], + ) -> Option<()> { + let mut decomp = decomp; + if decomp.is_fresh() { + while let Some(term) = decomp.next_term() { + term.update_slice_with_recomposition_summand_wrapping_addition(output); + } + Some(()) + } else { + None + } + } } diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 3d91888976..4b465d195c 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -1,6 +1,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus; use crate::core_crypto::commons::math::decomposition::{ - DecompositionLevel, DecompositionTerm, DecompositionTermNonNative, SignedDecomposerNonNative, + DecompositionLevel, DecompositionTerm, DecompositionTermNonNative, DecompositionTermSlice, + DecompositionTermSliceNonNative, SignedDecomposerNonNative, }; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; @@ -132,6 +133,152 @@ pub(crate) fn decompose_one_level( res.wrapping_sub(carry << base_log) } +/// An iterator-like object that yields the terms of the signed decomposition of a tensor of values. +/// +/// # Note +/// +/// On each call to [`SliceSignedDecompositionIter::next_term`], this structure yields a new +/// [`DecompositionTermSlice`], backed by a `Vec` owned by the structure. This vec is mutated at +/// each call of the `next_term` method, and as such the term must be dropped before `next_term` is +/// called again. +/// +/// Such a pattern can not be implemented with iterators yet (without GATs), which is why this +/// iterator must be explicitly called. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct SliceSignedDecompositionIter +where + T: UnsignedInteger, +{ + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: T, + // The internal states of each decomposition + states: Vec, + // In order to avoid allocating a new Vec every time we yield a decomposition term, we store + // a Vec inside the structure and yield slices pointing to it. + outputs: Vec, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, +} + +impl SliceSignedDecompositionIter +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition iterator. + pub(crate) fn new( + input: &[T], + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ) -> Self { + let len = input.len(); + Self { + base_log: base_log.0, + level_count: level.0, + current_level: level.0, + mod_b_mask: (T::ONE << base_log.0) - T::ONE, + outputs: vec![T::ZERO; len], + states: input + .iter() + .map(|i| *i >> (T::BITS - base_log.0 * level.0)) + .collect(), + fresh: true, + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Yield the next term of the decomposition, if any. + /// + /// # Note + /// + /// Because this function returns a borrowed tensor, owned by the iterator, the term must be + /// dropped before `next_term` is called again. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// assert_eq!(term.as_slice()[0], 4294967295); + /// ``` + pub fn next_term(&mut self) -> Option> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We iterate over the elements of the outputs and decompose + for (output_i, state_i) in self.outputs.iter_mut().zip(self.states.iter_mut()) { + *output_i = decompose_one_level(self.base_log, state_i, self.mod_b_mask); + } + self.current_level -= 1; + // We return the term tensor. + Some(DecompositionTermSlice::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + &self.outputs, + )) + } +} + /// An iterator that yields the terms of the signed decomposition of an integer. /// /// # Warning @@ -277,6 +424,191 @@ where } } +/// An iterator-like object that yields the terms of the signed decomposition of a tensor of values. +/// +/// # Note +/// +/// On each call to [`SliceSignedDecompositionNonNativeIter::next_term`], this structure yields a +/// new +/// [`DecompositionTermSlice`], backed by a `Vec` owned by the structure. This vec is mutated at +/// each call of the `next_term` method, and as such the term must be dropped before `next_term` is +/// called again. +/// +/// Such a pattern can not be implemented with iterators yet (without GATs), which is why this +/// iterator must be explicitly called. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct SliceSignedDecompositionNonNativeIter +where + T: UnsignedInteger, +{ + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: T, + // Ciphertext modulus + ciphertext_modulus: CiphertextModulus, + // The internal states of each decomposition + states: Vec, + // In order to avoid allocating a new Vec every time we yield a decomposition term, we store + // a Vec inside the structure and yield slices pointing to it. + outputs: Vec, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, + // The signs of the input values, for the algorithm we use, returned values require adaption + // depending on the sign of the input + signs: Vec, +} + +impl SliceSignedDecompositionNonNativeIter +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition iterator. + pub(crate) fn new( + input: &[T], + input_signs: &[ValueSign], + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self { + base_log: base_log.0, + level_count: level.0, + current_level: level.0, + mod_b_mask: (T::ONE << base_log.0) - T::ONE, + ciphertext_modulus, + outputs: vec![T::ZERO; input.len()], + states: input + .iter() + .map(|i| { + *i >> (ciphertext_modulus.get_custom_modulus().ceil_ilog2() as usize + - base_log.0 * level.0) + }) + .collect(), + fresh: true, + signs: input_signs.to_vec(), + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Yield the next term of the decomposition, if any. + /// + /// # Note + /// + /// Because this function returns a borrowed tensor, owned by the iterator, the term must be + /// dropped before `next_term` is called again. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{ + /// DecompositionLevel, SignedDecomposerNonNative, + /// }; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// assert_eq!(term.as_slice()[0], u32::MAX); + /// ``` + pub fn next_term(&mut self) -> Option> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We iterate over the elements of the outputs and decompose + for ((output_i, state_i), sign_i) in self + .outputs + .iter_mut() + .zip(self.states.iter_mut()) + .zip(self.signs.iter()) + { + *output_i = decompose_one_level(self.base_log, state_i, self.mod_b_mask); + *output_i = match sign_i { + ValueSign::Positive => *output_i, + ValueSign::Negative => output_i.wrapping_neg(), + }; + } + self.current_level -= 1; + // We return the term tensor. + Some(DecompositionTermSliceNonNative::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + &self.outputs, + self.ciphertext_modulus, + )) + } +} + /// Specialized high performance implementation of a non native decomposer over a collection of /// elements, used notably in the PBS. pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> { diff --git a/tfhe/src/core_crypto/commons/math/decomposition/term.rs b/tfhe/src/core_crypto/commons/math/decomposition/term.rs index b5146859f0..e7e5443e5f 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/term.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/term.rs @@ -223,3 +223,284 @@ where DecompositionLevel(self.level) } } + +/// A tensor whose elements are the terms of the decomposition of another tensor. +/// +/// If we decompose each elements of a set of values $(\theta^{(a)})\_{a\in\mathbb{N}}$ as a set of +/// sums $(\sum\_{i=1}^l\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$, this represents a +/// set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTermSlice<'a, T> +where + T: UnsignedInteger, +{ + level: usize, + base_log: usize, + slice: &'a [T], +} + +impl<'a, T> DecompositionTermSlice<'a, T> +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + slice: &'a [T], + ) -> Self { + Self { + level: level.0, + base_log: base_log.0, + slice, + } + } + + /// Fills the output tensor with the terms turned to summands. + /// + /// If our term tensor represents a set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ of the + /// decomposition, this method fills the output tensor with a set of + /// $(\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// let mut output = vec![0u32; 2]; + /// term.fill_slice_with_recomposition_summand(&mut output); + /// assert!(output.iter().all(|&x| x == 1048576)); + /// ``` + pub fn fill_slice_with_recomposition_summand(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(dst, &value)| { + let shift: usize = ::BITS - self.base_log * self.level; + *dst = value << shift + }); + } + + pub(crate) fn update_slice_with_recomposition_summand_wrapping_addition( + &self, + output: &mut [T], + ) { + assert_eq!(self.slice.len(), output.len()); + let shift: usize = ::BITS - self.base_log * self.level; + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(out, &value)| { + *out = (*out).wrapping_add(value << shift); + }); + } + + /// Returns a tensor with the values of term. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.as_slice()[0], 1); + /// ``` + pub fn as_slice(&self) -> &'a [T] { + self.slice + } + + /// Returns the level of this decomposition term tensor. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} + +/// A tensor whose elements are the terms of the decomposition of another tensor. +/// +/// If we decompose each elements of a set of values $(\theta^{(a)})\_{a\in\mathbb{N}}$ as a set of +/// sums $(\sum\_{i=1}^l\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$, this represents a +/// set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTermSliceNonNative<'a, T> +where + T: UnsignedInteger, +{ + level: usize, + base_log: usize, + slice: &'a [T], + ciphertext_modulus: CiphertextModulus, +} + +impl<'a, T> DecompositionTermSliceNonNative<'a, T> +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + slice: &'a [T], + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self { + level: level.0, + base_log: base_log.0, + slice, + ciphertext_modulus, + } + } + + /// Fills the output tensor with the terms turned to summands. + /// + /// If our term tensor represents a set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ of the + /// decomposition, this method fills the output tensor with a set of + /// $(\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// let mut output = vec![0; 2]; + /// term.to_approximate_recomposition_summand(&mut output); + /// assert!(output.iter().all(|&x| x == 1048576)); + /// ``` + pub fn to_approximate_recomposition_summand(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + let ciphertext_modulus_bit_count: usize = modulus_as_t.ceil_ilog2().try_into().unwrap(); + let shift: usize = ciphertext_modulus_bit_count - self.base_log * self.level; + + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(dst, &value)| { + if value.into_signed() >= T::Signed::ZERO { + *dst = value << shift + } else { + *dst = modulus_as_t.wrapping_add(value << shift) + } + }); + } + + /// Compute the value of the term modulo the modulus given when building the + /// [`DecompositionTermSliceNonNative`] + pub fn modular_value(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + self.slice + .iter() + .zip(output.iter_mut()) + .for_each(|(&value, output)| { + if value.into_signed() >= T::Signed::ZERO { + *output = value + } else { + *output = modulus_as_t.wrapping_add(value) + } + }); + } + + pub(crate) fn update_slice_with_recomposition_summand_wrapping_addition( + &self, + output: &mut [T], + ) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + let ciphertext_modulus_bit_count: usize = modulus_as_t.ceil_ilog2().try_into().unwrap(); + let shift: usize = ciphertext_modulus_bit_count - self.base_log * self.level; + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(out, &value)| { + if value.into_signed() >= T::Signed::ZERO { + *out = (*out).wrapping_add_custom_mod(value << shift, modulus_as_t) + } else { + *out = (*out).wrapping_add_custom_mod( + modulus_as_t.wrapping_add(value << shift), + modulus_as_t, + ) + } + }); + } + + /// Returns a tensor with the values of term. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.as_slice()[0], 1); + /// ``` + pub fn as_slice(&self) -> &'a [T] { + self.slice + } + + /// Returns the level of this decomposition term tensor. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{ + /// DecompositionLevel, SignedDecomposerNonNative, + /// }; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} diff --git a/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs b/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs new file mode 100644 index 0000000000..cc80d9e9a6 --- /dev/null +++ b/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs @@ -0,0 +1,505 @@ +//! Module containing the definition of the [`GlweKeyswitchKey`]. + +use crate::conformance::ParameterSetConformant; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// A [`GLWE keyswitch key`](`GlweKeyswitchKey`). +/// +/// # Formal Definition +/// +/// ## Key Switching Key +/// +/// A key switching key is a vector of GLev ciphertexts (described on the bottom of +/// [`this page`](`crate::core_crypto::entities::GgswCiphertext#Glev-ciphertext`)). +/// It encrypts the coefficient of +/// the [`GLWE secret key`](`crate::core_crypto::entities::GlweSecretKey`) +/// $\vec{S}\_{\mathsf{in}}$ under the +/// [`GLWE secret key`](`crate::core_crypto::entities::GlweSecretKey`) +/// $\vec{S}\_{\mathsf{out}}$. +/// +/// $$\mathsf{KSK}\_{\vec{S}\_{\mathsf{in}}\rightarrow \vec{S}\_{\mathsf{out}}} = \left( +/// \overline{\mathsf{CT}\_0}, \cdots , \overline{\mathsf{CT}\_{k\_{\mathsf{in}}-1}}\right) +/// \subseteq R\_q^{(k\_{\mathsf{out}}+1)\cdot k\_{\mathsf{in}}\cdot \ell}$$ +/// +/// where $\vec{S}\_{\mathsf{in}} = \left( S\_0 , \cdots , S\_{\mathsf{in}-1} \right)$ and for all +/// $0\le i +where + C::Element: UnsignedInteger, +{ + data: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, +} + +impl> AsRef<[T]> for GlweKeyswitchKey { + fn as_ref(&self) -> &[T] { + self.data.as_ref() + } +} + +impl> AsMut<[T]> for GlweKeyswitchKey { + fn as_mut(&mut self) -> &mut [T] { + self.data.as_mut() + } +} + +/// Return the number of elements in an encryption of an input [`GlweSecretKey`] element for a +/// [`GlweKeyswitchKey`] given a [`DecompositionLevelCount`] and output [`GlweSize`] and +/// [`PolynomialSize`]. +pub fn glwe_keyswitch_key_input_key_element_encrypted_size( + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + // One ciphertext per level encrypted under the output key + decomp_level_count.0 * output_glwe_size.0 * poly_size.0 +} + +impl> GlweKeyswitchKey { + /// Create an [`GlweKeyswitchKey`] from an existing container. + /// + /// # Note + /// + /// This function only wraps a container in the appropriate type. If you want to generate an + /// [`GlweKeyswitchKey`] you need to call + /// [`crate::core_crypto::algorithms::generate_glwe_keyswitch_key`] using this key as output. + /// + /// This docstring exhibits [`GlweKeyswitchKey`] primitives usage. + /// + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// + /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + /// // computations + /// // Define parameters for LweKeyswitchKey creation + /// let input_glwe_dimension = GlweDimension(1); + /// let output_glwe_dimension = GlweDimension(2); + /// let poly_size = PolynomialSize(1024); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let ciphertext_modulus = CiphertextModulus::new_native(); + /// + /// // Create a new LweKeyswitchKey + /// let glwe_ksk = GlweKeyswitchKey::new( + /// 0u64, + /// decomp_base_log, + /// decomp_level_count, + /// input_glwe_dimension, + /// output_glwe_dimension, + /// poly_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(glwe_ksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(glwe_ksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(glwe_ksk.input_key_glwe_dimension(), input_glwe_dimension); + /// assert_eq!(glwe_ksk.output_key_glwe_dimension(), output_glwe_dimension); + /// assert_eq!(glwe_ksk.polynomial_size(), poly_size); + /// assert_eq!( + /// glwe_ksk.output_glwe_size(), + /// output_glwe_dimension.to_glwe_size() + /// ); + /// assert_eq!(glwe_ksk.ciphertext_modulus(), ciphertext_modulus); + /// + /// // Demonstrate how to recover the allocated container + /// let underlying_container: Vec = glwe_ksk.into_container(); + /// + /// // Recreate a keyswitch key using from_container + /// let glwe_ksk = GlweKeyswitchKey::from_container( + /// underlying_container, + /// decomp_base_log, + /// decomp_level_count, + /// output_glwe_dimension.to_glwe_size(), + /// poly_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(glwe_ksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(glwe_ksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(glwe_ksk.input_key_glwe_dimension(), input_glwe_dimension); + /// assert_eq!(glwe_ksk.output_key_glwe_dimension(), output_glwe_dimension); + /// assert_eq!( + /// glwe_ksk.output_glwe_size(), + /// output_glwe_dimension.to_glwe_size() + /// ); + /// assert_eq!(glwe_ksk.ciphertext_modulus(), ciphertext_modulus); + /// ``` + pub fn from_container( + container: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert!( + container.container_len() > 0, + "Got an empty container to create a GlweKeyswitchKey" + ); + assert!( + container.container_len() % (decomp_level_count.0 * output_glwe_size.0 * poly_size.0) + == 0, + "The provided container length is not valid. \ + It needs to be dividable by decomp_level_count * output_glwe_size * output_poly_size: {}. \ + Got container length: {} and decomp_level_count: {decomp_level_count:?}, \ + output_glwe_size: {output_glwe_size:?}, poly_size: {poly_size:?}.", + decomp_level_count.0 * output_glwe_size.0 * poly_size.0, + container.container_len() + ); + + Self { + data: container, + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + } + } + + /// Return the [`DecompositionBaseLog`] of the [`LweKeyswitchKey`]. + /// + /// See [`LweKeyswitchKey::from_container`] for usage. + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Return the [`DecompositionLevelCount`] of the [`LweKeyswitchKey`]. + /// + /// See [`LweKeyswitchKey::from_container`] for usage. + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomp_level_count + } + + /// Return the input [`GlweDimension`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn input_key_glwe_dimension(&self) -> GlweDimension { + GlweDimension(self.data.container_len() / self.input_key_element_encrypted_size()) + } + + /// Return the input [`PolynomialSize`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Return the output [`GlweDimension`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn output_key_glwe_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Return the output [`GlweSize`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn output_glwe_size(&self) -> GlweSize { + self.output_glwe_size + } + + /// Return the number of elements in an encryption of an input [`GlweSecretKey`] element of the + /// current [`GlweKeyswitchKey`]. + pub fn input_key_element_encrypted_size(&self) -> usize { + glwe_keyswitch_key_input_key_element_encrypted_size( + self.decomp_level_count, + self.output_glwe_size, + self.poly_size, + ) + } + + /// Return a view of the [`GlweKeyswitchKey`]. This is useful if an algorithm takes a view by + /// value. + pub fn as_view(&self) -> GlweKeyswitchKey<&'_ [Scalar]> { + GlweKeyswitchKey::from_container( + self.as_ref(), + self.decomp_base_log, + self.decomp_level_count, + self.output_glwe_size, + self.poly_size, + self.ciphertext_modulus, + ) + } + + /// Consume the entity and return its underlying container. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn into_container(self) -> C { + self.data + } + + pub fn as_glwe_ciphertext_list(&self) -> GlweCiphertextListView<'_, Scalar> { + GlweCiphertextListView::from_container( + self.as_ref(), + self.output_glwe_size(), + self.polynomial_size(), + self.ciphertext_modulus(), + ) + } + + /// Return the [`CiphertextModulus`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn ciphertext_modulus(&self) -> CiphertextModulus { + self.ciphertext_modulus + } +} + +impl> GlweKeyswitchKey { + /// Mutable variant of [`GlweKeyswitchKey::as_view`]. + pub fn as_mut_view(&mut self) -> GlweKeyswitchKey<&'_ mut [Scalar]> { + let decomp_base_log = self.decomp_base_log; + let decomp_level_count = self.decomp_level_count; + let output_glwe_size = self.output_glwe_size; + let poly_size = self.poly_size; + let ciphertext_modulus = self.ciphertext_modulus; + GlweKeyswitchKey::from_container( + self.as_mut(), + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + ) + } + + pub fn as_mut_glwe_ciphertext_list(&mut self) -> GlweCiphertextListMutView<'_, Scalar> { + let output_glwe_size = self.output_glwe_size(); + let poly_size = self.polynomial_size(); + let ciphertext_modulus = self.ciphertext_modulus(); + GlweCiphertextListMutView::from_container( + self.as_mut(), + output_glwe_size, + poly_size, + ciphertext_modulus, + ) + } +} + +/// A [`GlweKeyswitchKey`] owning the memory for its own storage. +pub type GlweKeyswitchKeyOwned = GlweKeyswitchKey>; +/// A [`GlweKeyswitchKey`] immutably borrowing memory for its own storage. +pub type GlweKeyswitchKeyView<'data, Scalar> = GlweKeyswitchKey<&'data [Scalar]>; +/// A [`GlweKeyswitchKey`] mutably borrowing memory for its own storage. +pub type GlweKeyswitchKeyMutView<'data, Scalar> = GlweKeyswitchKey<&'data mut [Scalar]>; + +impl GlweKeyswitchKeyOwned { + /// Allocate memory and create a new owned [`GlweKeyswitchKey`]. + /// + /// # Note + /// + /// This function allocates a vector of the appropriate size and wraps it in the appropriate + /// type. If you want to generate an [`GlweKeyswitchKey`] you need to call + /// [`crate::core_crypto::algorithms::generate_glwe_keyswitch_key`] using this key as output. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn new( + fill_with: Scalar, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_key_glwe_dimension: GlweDimension, + output_key_glwe_dimension: GlweDimension, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self::from_container( + vec![ + fill_with; + input_key_glwe_dimension.0 + * glwe_keyswitch_key_input_key_element_encrypted_size( + decomp_level_count, + output_key_glwe_dimension.to_glwe_size(), + poly_size, + ) + ], + decomp_base_log, + decomp_level_count, + output_key_glwe_dimension.to_glwe_size(), + poly_size, + ciphertext_modulus, + ) + } +} + +#[derive(Clone, Copy)] +pub struct GlweKeyswitchKeyCreationMetadata { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub output_glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> CreateFrom + for GlweKeyswitchKey +{ + type Metadata = GlweKeyswitchKeyCreationMetadata; + + #[inline] + fn create_from(from: C, meta: Self::Metadata) -> Self { + let GlweKeyswitchKeyCreationMetadata { + decomp_base_log, + decomp_level_count, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + } = meta; + Self::from_container( + from, + decomp_base_log, + decomp_level_count, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +impl> ContiguousEntityContainer + for GlweKeyswitchKey +{ + type Element = C::Element; + + type EntityViewMetadata = GlweCiphertextListCreationMetadata; + + type EntityView<'this> + = GlweCiphertextListView<'this, Self::Element> + where + Self: 'this; + + type SelfViewMetadata = GlweKeyswitchKeyCreationMetadata; + + type SelfView<'this> + = GlweKeyswitchKeyView<'this, Self::Element> + where + Self: 'this; + + fn get_entity_view_creation_metadata(&self) -> Self::EntityViewMetadata { + GlweCiphertextListCreationMetadata { + glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } + + fn get_entity_view_pod_size(&self) -> usize { + self.input_key_element_encrypted_size() + } + + fn get_self_view_creation_metadata(&self) -> Self::SelfViewMetadata { + GlweKeyswitchKeyCreationMetadata { + decomp_base_log: self.decomposition_base_log(), + decomp_level_count: self.decomposition_level_count(), + output_glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } +} + +impl> ContiguousEntityContainerMut + for GlweKeyswitchKey +{ + type EntityMutView<'this> + = GlweCiphertextListMutView<'this, Self::Element> + where + Self: 'this; + + type SelfMutView<'this> + = GlweKeyswitchKeyMutView<'this, Self::Element> + where + Self: 'this; +} + +pub struct GlweKeyswitchKeyConformanceParams { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub output_glwe_size: GlweSize, + pub input_glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> ParameterSetConformant for GlweKeyswitchKey { + type ParameterSet = GlweKeyswitchKeyConformanceParams; + + fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { + let Self { + data, + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + } = self; + + *ciphertext_modulus == parameter_set.ciphertext_modulus + && data.container_len() + == parameter_set.input_glwe_dimension.0 + * glwe_keyswitch_key_input_key_element_encrypted_size( + parameter_set.decomp_level_count, + parameter_set.output_glwe_size, + parameter_set.polynomial_size, + ) + && *decomp_base_log == parameter_set.decomp_base_log + && *decomp_level_count == parameter_set.decomp_level_count + && *output_glwe_size == parameter_set.output_glwe_size + && *poly_size == parameter_set.polynomial_size + } +} diff --git a/tfhe/src/core_crypto/entities/glwe_relinearization_key.rs b/tfhe/src/core_crypto/entities/glwe_relinearization_key.rs new file mode 100644 index 0000000000..c9b982d02b --- /dev/null +++ b/tfhe/src/core_crypto/entities/glwe_relinearization_key.rs @@ -0,0 +1,422 @@ +//! Module containing the definition of the [`GlweRelinearizationKey`]. + +use crate::conformance::ParameterSetConformant; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// A [`GLWE relinearization key`](`GlweRelinearizationKey`). +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct GlweRelinearizationKey +where + C::Element: UnsignedInteger, +{ + data: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, +} + +impl> AsRef<[T]> for GlweRelinearizationKey { + fn as_ref(&self) -> &[T] { + self.data.as_ref() + } +} + +impl> AsMut<[T]> for GlweRelinearizationKey { + fn as_mut(&mut self) -> &mut [T] { + self.data.as_mut() + } +} + +/// Return the number of elements in an encryption of an input [`GlweSecretKey`] element for a +/// [`GlweRelinearizationKey`] given a [`DecompositionLevelCount`], [`GlweSize`] and +/// [`PolynomialSize`]. +pub fn glwe_relinearization_key_input_key_element_encrypted_size( + decomp_level_count: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, +) -> usize { + // One ciphertext per level encrypted under the output key + decomp_level_count.0 * glwe_size.0 * polynomial_size.0 +} + +/// Return the number of elements in a [`GlweRelinearizationKey`] given a +/// [`DecompositionLevelCount`], [`GlweSize`], and [`PolynomialSize`]. +pub fn glwe_relinearization_key_size( + decomp_level_count: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, +) -> usize { + (glwe_size.to_glwe_dimension().0 * (glwe_size.to_glwe_dimension().0 + 1)) / 2 + * glwe_relinearization_key_input_key_element_encrypted_size( + decomp_level_count, + glwe_size, + polynomial_size, + ) +} + +impl> GlweRelinearizationKey { + /// Create a [`GlweRelinearizationKey`] from an existing container. + /// + /// # Note + /// + /// This function only wraps a container in the appropriate type. If you want to generate an + /// [`GlweRelinearizationKey`] you need to use + /// [`crate::core_crypto::algorithms::generate_glwe_relinearization_key`] + /// using this key as output. + /// + /// This docstring exhibits [`GlweRelinearizationKey`] primitives usage. + /// + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// + /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + /// // computations + /// // Define parameters for GlweRelinearizationKey creation + /// let glwe_size = GlweSize(3); + /// let polynomial_size = PolynomialSize(1024); + /// let decomp_base_log = DecompositionBaseLog(8); + /// let decomp_level_count = DecompositionLevelCount(3); + /// let ciphertext_modulus = CiphertextModulus::new_native(); + /// + /// // Create a new GlweRelinearizationKey + /// let relin_key = GlweRelinearizationKey::new( + /// 0u64, + /// decomp_base_log, + /// decomp_level_count, + /// glwe_size, + /// polynomial_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(relin_key.glwe_dimension(), glwe_size.to_glwe_dimension()); + /// assert_eq!(relin_key.glwe_size(), glwe_size); + /// assert_eq!(relin_key.polynomial_size(), polynomial_size); + /// assert_eq!(relin_key.decomposition_base_log(), decomp_base_log); + /// assert_eq!(relin_key.decomposition_level_count(), decomp_level_count); + /// assert_eq!(relin_key.ciphertext_modulus(), ciphertext_modulus); + /// + /// // Demonstrate how to recover the allocated container + /// let underlying_container: Vec = relin_key.into_container(); + /// + /// // Recreate a key using from_container + /// let relin_key = GlweRelinearizationKey::from_container( + /// underlying_container, + /// decomp_base_log, + /// decomp_level_count, + /// glwe_size, + /// polynomial_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(relin_key.glwe_dimension(), glwe_size.to_glwe_dimension()); + /// assert_eq!(relin_key.glwe_size(), glwe_size); + /// assert_eq!(relin_key.polynomial_size(), polynomial_size); + /// assert_eq!(relin_key.decomposition_base_log(), decomp_base_log); + /// assert_eq!(relin_key.decomposition_level_count(), decomp_level_count); + /// assert_eq!(relin_key.ciphertext_modulus(), ciphertext_modulus); + /// ``` + pub fn from_container( + container: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert!( + container.container_len() > 0, + "Got an empty container to create a GlweRelinearizationKey" + ); + assert!( + container.container_len() + % glwe_relinearization_key_input_key_element_encrypted_size( + decomp_level_count, + glwe_size, + polynomial_size + ) + == 0, + "The provided container length is not valid. \ + It needs to be divisible by decomp_level_count * glwe_size * polynomial_size:\ + {}. Got container length: {} and decomp_level_count: {decomp_level_count:?}, \ + glwe_size: {glwe_size:?}, polynomial_size: {polynomial_size:?}.", + glwe_relinearization_key_input_key_element_encrypted_size( + decomp_level_count, + glwe_size, + polynomial_size + ), + container.container_len() + ); + + Self { + data: container, + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + } + } + + /// Return the [`GlweDimension`] of the [`GlweRelinearizationKey`]. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn glwe_dimension(&self) -> GlweDimension { + self.glwe_size.to_glwe_dimension() + } + + /// Return the [`GlweSize`] of the [`GlweRelinearizationKey`]. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Return the output [`PolynomialSize`] of the [`GlweRelinearizationKey`]. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + /// Return the [`DecompositionLevelCount`] of the [`GlweRelinearizationKey`]. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomp_level_count + } + + /// Return the [`DecompositionBaseLog`] of the [`GlweRelinearizationKey`]. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Return the number of elements in an encryption of an input [`GlweSecretKey`] element of the + /// current [`GlweRelinearizationKey`]. + pub fn input_key_element_encrypted_size(&self) -> usize { + glwe_relinearization_key_input_key_element_encrypted_size( + self.decomp_level_count, + self.glwe_size, + self.polynomial_size, + ) + } + + /// Return a view of the [`GlweRelinearizationKey`]. This is useful if an + /// algorithm takes a view by value. + pub fn as_view(&self) -> GlweRelinearizationKey<&'_ [Scalar]> { + GlweRelinearizationKey::from_container( + self.as_ref(), + self.decomp_base_log, + self.decomp_level_count, + self.glwe_size, + self.polynomial_size, + self.ciphertext_modulus, + ) + } + + /// Consume the entity and return its underlying container. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn into_container(self) -> C { + self.data + } + + /// Return the [`CiphertextModulus`] of the [`GlweRelinearizationKey`] + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn ciphertext_modulus(&self) -> CiphertextModulus { + self.ciphertext_modulus + } +} + +impl> GlweRelinearizationKey { + /// Mutable variant of [`GlweRelinearizationKey::as_view`]. + pub fn as_mut_view(&mut self) -> GlweRelinearizationKey<&'_ mut [Scalar]> { + let decomp_base_log = self.decomp_base_log; + let decomp_level_count = self.decomp_level_count; + let glwe_size = self.glwe_size; + let polynomial_size = self.polynomial_size; + let ciphertext_modulus = self.ciphertext_modulus; + + GlweRelinearizationKey::from_container( + self.as_mut(), + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +/// A [`GlweRelinearizationKey`] owning the memory for its own storage. +pub type GlweRelinearizationKeyOwned = GlweRelinearizationKey>; +/// A [`GlweRelinearizationKey`] immutably borrowing memory for its own storage. +pub type GlweRelinearizationKeyView<'data, Scalar> = GlweRelinearizationKey<&'data [Scalar]>; +/// A [`GlweRelinearizationKey`] mutably borrowing memory for its own storage. +pub type GlweRelinearizationKeyMutView<'data, Scalar> = GlweRelinearizationKey<&'data mut [Scalar]>; + +impl GlweRelinearizationKeyOwned { + /// Create a new [`GlweRelinearizationKey`]. + /// + /// # Note + /// + /// This function allocates a vector of the appropriate size and wraps it in the appropriate + /// type. If you want to generate an [`GlweRelinearizationKey`] you need to use + /// [`crate::core_crypto::algorithms::generate_glwe_relinearization_key`] + /// using this key as output. + /// + /// See [`GlweRelinearizationKey::from_container`] for usage. + pub fn new( + fill_with: Scalar, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self::from_container( + vec![ + fill_with; + glwe_relinearization_key_size(decomp_level_count, glwe_size, polynomial_size) + ], + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +/// Metadata used in the [`CreateFrom`] implementation to create +/// [`GlweRelinearizationKey`] entities. +#[derive(Clone, Copy)] +pub struct GlweRelinearizationKeyCreationMetadata { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> CreateFrom + for GlweRelinearizationKey +{ + type Metadata = GlweRelinearizationKeyCreationMetadata; + + #[inline] + fn create_from(from: C, meta: Self::Metadata) -> Self { + let GlweRelinearizationKeyCreationMetadata { + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + } = meta; + Self::from_container( + from, + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +impl> ContiguousEntityContainer + for GlweRelinearizationKey +{ + type Element = C::Element; + + type EntityViewMetadata = GlweCiphertextListCreationMetadata; + + type EntityView<'this> + = GlweCiphertextListView<'this, Self::Element> + where + Self: 'this; + + type SelfViewMetadata = GlweRelinearizationKeyCreationMetadata; + + type SelfView<'this> + = GlweRelinearizationKeyView<'this, Self::Element> + where + Self: 'this; + + fn get_entity_view_creation_metadata(&self) -> Self::EntityViewMetadata { + GlweCiphertextListCreationMetadata { + glwe_size: self.glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } + + fn get_entity_view_pod_size(&self) -> usize { + self.input_key_element_encrypted_size() + } + + fn get_self_view_creation_metadata(&self) -> Self::SelfViewMetadata { + GlweRelinearizationKeyCreationMetadata { + decomp_base_log: self.decomposition_base_log(), + decomp_level_count: self.decomposition_level_count(), + glwe_size: self.glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } +} + +impl> ContiguousEntityContainerMut + for GlweRelinearizationKey +{ + type EntityMutView<'this> + = GlweCiphertextListMutView<'this, Self::Element> + where + Self: 'this; + + type SelfMutView<'this> + = GlweRelinearizationKeyMutView<'this, Self::Element> + where + Self: 'this; +} + +pub struct RelinearizationKeyConformanceParmas { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> ParameterSetConformant for GlweRelinearizationKey { + type ParameterSet = RelinearizationKeyConformanceParmas; + + fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { + let Self { + data, + decomp_base_log, + decomp_level_count, + glwe_size, + polynomial_size, + ciphertext_modulus, + } = self; + + *ciphertext_modulus == parameter_set.ciphertext_modulus + && data.container_len() + == glwe_relinearization_key_size( + parameter_set.decomp_level_count, + parameter_set.glwe_size, + parameter_set.polynomial_size, + ) + && *decomp_base_log == parameter_set.decomp_base_log + && *decomp_level_count == parameter_set.decomp_level_count + && *glwe_size == parameter_set.glwe_size + && *polynomial_size == parameter_set.polynomial_size + } +} diff --git a/tfhe/src/core_crypto/entities/lwe_trace_packing_keyswitch_key.rs b/tfhe/src/core_crypto/entities/lwe_trace_packing_keyswitch_key.rs new file mode 100644 index 0000000000..2ecdf50127 --- /dev/null +++ b/tfhe/src/core_crypto/entities/lwe_trace_packing_keyswitch_key.rs @@ -0,0 +1,558 @@ +//! Module containing the definition of the [`LweTracePackingKeyswitchKey`]. + +use crate::conformance::ParameterSetConformant; +use crate::core_crypto::commons::generators::EncryptionRandomGeneratorForkConfig; +use crate::core_crypto::commons::math::random::{Distribution, RandomGenerable}; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// An [`LWE trace packing keyswitch key`](`LweTracePackingKeyswitchKey`). +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct LweTracePackingKeyswitchKey +where + C::Element: UnsignedInteger, +{ + data: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_lwe_size: LweSize, + output_glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, +} + +impl> AsRef<[T]> for LweTracePackingKeyswitchKey { + fn as_ref(&self) -> &[T] { + self.data.as_ref() + } +} + +impl> AsMut<[T]> + for LweTracePackingKeyswitchKey +{ + fn as_mut(&mut self) -> &mut [T] { + self.data.as_mut() + } +} + +/// Return the number of elements in an encryption of an input [`GlweSecretKey`] element for a +/// [`LweTracePackingKeyswitchKey`] given a [`DecompositionLevelCount`] and output +/// [`GlweSize`] and [`PolynomialSize`]. +pub fn lwe_tpksk_input_key_element_encrypted_size( + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + polynomial_size: PolynomialSize, +) -> usize { + // One ciphertext per level encrypted under the output key + decomp_level_count.0 * output_glwe_size.0 * polynomial_size.0 +} + +/// Return the number of elements in an [`LweTracePackingKeyswitchKey`] given a +/// [`DecompositionLevelCount`], output [`GlweSize`], and output [`PolynomialSize`]. +pub fn lwe_tpksk_size( + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + polynomial_size: PolynomialSize, +) -> usize { + output_glwe_size.to_glwe_dimension().0 + * polynomial_size.log2().0 + * lwe_tpksk_input_key_element_encrypted_size( + decomp_level_count, + output_glwe_size, + polynomial_size, + ) +} + +pub fn tpksk_encryption_mask_sample_count( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomp_level_count: DecompositionLevelCount, +) -> EncryptionMaskSampleCount { + decomp_level_count.0 + * glwe_size.to_glwe_dimension().0 + * glwe_ciphertext_encryption_mask_sample_count( + glwe_size.to_glwe_dimension(), + polynomial_size, + ) +} + +pub fn tpksk_encryption_noise_sample_count( + polynomial_size: PolynomialSize, + decomp_level_count: DecompositionLevelCount, +) -> EncryptionNoiseSampleCount { + decomp_level_count.0 * glwe_ciphertext_encryption_noise_sample_count(polynomial_size) +} + +pub fn lwe_trace_packing_keyswitch_key_encryption_fork_config< + Scalar, + MaskDistribution, + NoiseDistribution, +>( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomposition_level_count: DecompositionLevelCount, + mask_distribution: MaskDistribution, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, +) -> EncryptionRandomGeneratorForkConfig +where + Scalar: UnsignedInteger + + RandomGenerable + + RandomGenerable, + MaskDistribution: Distribution, + NoiseDistribution: Distribution, +{ + let tpksk_mask_sample_count = + tpksk_encryption_mask_sample_count(glwe_size, polynomial_size, decomposition_level_count); + let tpksk_noise_sample_count = + tpksk_encryption_noise_sample_count(polynomial_size, decomposition_level_count); + + let modulus = ciphertext_modulus.get_custom_modulus_as_optional_scalar(); + + EncryptionRandomGeneratorForkConfig::new( + polynomial_size.log2().0, + tpksk_mask_sample_count, + mask_distribution, + tpksk_noise_sample_count, + noise_distribution, + modulus, + ) +} + +impl> LweTracePackingKeyswitchKey { + /// Create an [`LweTracePackingKeyswitchKey`] from an existing container. + /// + /// # Note + /// + /// This function only wraps a container in the appropriate type. If you want to generate an + /// [`LweTracePackingKeyswitchKey`] you need to use + /// [`crate::core_crypto::algorithms::generate_lwe_trace_packing_keyswitch_key`] + /// using this key as output. + /// + /// This docstring exhibits [`LweTracePackingKeyswitchKey`] primitives usage. + /// + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// + /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + /// // computations + /// // Define parameters for LweTracePackingKeyswitchKey creation + /// let lwe_size = LweSize(1001); + /// let glwe_size = GlweSize(2); + /// let polynomial_size = PolynomialSize(1024); + /// let decomp_base_log = DecompositionBaseLog(8); + /// let decomp_level_count = DecompositionLevelCount(3); + /// let ciphertext_modulus = CiphertextModulus::new_native(); + /// + /// // Create a new LweTracePackingKeyswitchKey + /// let tpksk = LweTracePackingKeyswitchKey::new( + /// 0u64, + /// decomp_base_log, + /// decomp_level_count, + /// lwe_size, + /// glwe_size, + /// polynomial_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!( + /// tpksk.output_glwe_key_dimension(), + /// glwe_size.to_glwe_dimension() + /// ); + /// assert_eq!(tpksk.output_glwe_size(), glwe_size); + /// assert_eq!(tpksk.polynomial_size(), polynomial_size); + /// assert_eq!(tpksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(tpksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(tpksk.ciphertext_modulus(), ciphertext_modulus); + /// + /// // Demonstrate how to recover the allocated container + /// let underlying_container: Vec = tpksk.into_container(); + /// + /// // Recreate a key using from_container + /// let tpksk = LweTracePackingKeyswitchKey::from_container( + /// underlying_container, + /// decomp_base_log, + /// decomp_level_count, + /// lwe_size, + /// glwe_size, + /// polynomial_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!( + /// tpksk.output_glwe_key_dimension(), + /// glwe_size.to_glwe_dimension() + /// ); + /// assert_eq!(tpksk.input_lwe_size(), lwe_size); + /// assert_eq!(tpksk.output_glwe_size(), glwe_size); + /// assert_eq!(tpksk.polynomial_size(), polynomial_size); + /// assert_eq!(tpksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(tpksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(tpksk.ciphertext_modulus(), ciphertext_modulus); + /// ``` + pub fn from_container( + container: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_lwe_size: LweSize, + output_glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert!( + container.container_len() > 0, + "Got an empty container to create a LweTracePackingKeyswitchKey" + ); + assert!( + container.container_len() + % lwe_tpksk_input_key_element_encrypted_size( + decomp_level_count, + output_glwe_size, + polynomial_size + ) + == 0, + "The provided container length is not valid. \ + It needs to be divisible by decomp_level_count * output_glwe_size * polynomial_size:\ + {}. Got container length: {} and decomp_level_count: {decomp_level_count:?}, \ + output_glwe_size: {output_glwe_size:?}, polynomial_size: \ + {polynomial_size:?}.", + lwe_tpksk_input_key_element_encrypted_size( + decomp_level_count, + output_glwe_size, + polynomial_size + ), + container.container_len() + ); + + Self { + data: container, + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + } + } + + /// Return the output key [`GlweDimension`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn output_glwe_key_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Return the output [`GlweSize`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn output_glwe_size(&self) -> GlweSize { + self.output_glwe_size + } + + /// Return the output [`PolynomialSize`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + /// Return the input [`LweSize`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn input_lwe_size(&self) -> LweSize { + self.input_lwe_size + } + + /// Return the [`DecompositionLevelCount`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomp_level_count + } + + /// Return the [`DecompositionBaseLog`] of the [`LweTracePackingKeyswitchKey`]. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Return the number of elements in an encryption of an input [`LweSecretKey`] element of the + /// current [`LweTracePackingKeyswitchKey`]. + pub fn input_key_element_encrypted_size(&self) -> usize { + lwe_tpksk_input_key_element_encrypted_size( + self.decomp_level_count, + self.output_glwe_size, + self.polynomial_size, + ) + } + + /// Return a view of the [`LweTracePackingKeyswitchKey`]. This is useful if an + /// algorithm takes a view by value. + pub fn as_view(&self) -> LweTracePackingKeyswitchKey<&'_ [Scalar]> { + LweTracePackingKeyswitchKey::from_container( + self.as_ref(), + self.decomp_base_log, + self.decomp_level_count, + self.input_lwe_size, + self.output_glwe_size, + self.polynomial_size, + self.ciphertext_modulus, + ) + } + + /// Consume the entity and return its underlying container. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn into_container(self) -> C { + self.data + } + + pub fn as_glwe_ciphertext_list(&self) -> GlweCiphertextListView<'_, Scalar> { + GlweCiphertextListView::from_container( + self.as_ref(), + self.output_glwe_size(), + self.polynomial_size(), + self.ciphertext_modulus(), + ) + } + + /// Return the [`CiphertextModulus`] of the [`LweTracePackingKeyswitchKey`] + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn ciphertext_modulus(&self) -> CiphertextModulus { + self.ciphertext_modulus + } + + pub fn encryption_fork_config( + &self, + mask_distribution: MaskDistribution, + noise_distribution: NoiseDistribution, + ) -> EncryptionRandomGeneratorForkConfig + where + MaskDistribution: Distribution, + NoiseDistribution: Distribution, + Scalar: RandomGenerable + + RandomGenerable, + { + lwe_trace_packing_keyswitch_key_encryption_fork_config( + self.output_glwe_size(), + self.polynomial_size(), + self.decomposition_level_count(), + mask_distribution, + noise_distribution, + self.ciphertext_modulus(), + ) + } +} + +impl> LweTracePackingKeyswitchKey { + /// Mutable variant of [`LweTracePackingKeyswitchKey::as_view`]. + pub fn as_mut_view(&mut self) -> LweTracePackingKeyswitchKey<&'_ mut [Scalar]> { + let decomp_base_log = self.decomp_base_log; + let decomp_level_count = self.decomp_level_count; + let input_lwe_size = self.input_lwe_size; + let output_glwe_size = self.output_glwe_size; + let polynomial_size = self.polynomial_size; + let ciphertext_modulus = self.ciphertext_modulus; + LweTracePackingKeyswitchKey::from_container( + self.as_mut(), + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } + + pub fn as_mut_glwe_ciphertext_list(&mut self) -> GlweCiphertextListMutView<'_, Scalar> { + let output_glwe_size = self.output_glwe_size(); + let output_polynomial_size = self.polynomial_size(); + let ciphertext_modulus = self.ciphertext_modulus(); + GlweCiphertextListMutView::from_container( + self.as_mut(), + output_glwe_size, + output_polynomial_size, + ciphertext_modulus, + ) + } +} + +/// An [`LweTracePackingKeyswitchKey`] owning the memory for its own storage. +pub type LweTracePackingKeyswitchKeyOwned = LweTracePackingKeyswitchKey>; +/// An [`LweTracePackingKeyswitchKey`] immutably borrowing memory for its own storage. +pub type LweTracePackingKeyswitchKeyView<'data, Scalar> = + LweTracePackingKeyswitchKey<&'data [Scalar]>; +/// An [`LweTracePackingKeyswitchKey`] mutably borrowing memory for its own storage. +pub type LweTracePackingKeyswitchKeyMutView<'data, Scalar> = + LweTracePackingKeyswitchKey<&'data mut [Scalar]>; + +impl LweTracePackingKeyswitchKeyOwned { + /// Create an [`LweTracePackingKeyswitchKey`] from an existing container. + /// + /// # Note + /// + /// This function allocates a vector of the appropriate size and wraps it in the appropriate + /// type. If you want to generate an [`LweTracePackingKeyswitchKey`] you need to use + /// [`crate::core_crypto::algorithms::generate_lwe_trace_packing_keyswitch_key`] + /// using this key as output. + /// + /// See [`LweTracePackingKeyswitchKey::from_container`] for usage. + pub fn new( + fill_with: Scalar, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_lwe_size: LweSize, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self::from_container( + vec![fill_with; lwe_tpksk_size(decomp_level_count, output_glwe_size, poly_size)], + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + poly_size, + ciphertext_modulus, + ) + } +} + +/// Metadata used in the [`CreateFrom`] implementation to create +/// [`LweTracePackingKeyswitchKey`] entities. +#[derive(Clone, Copy)] +pub struct LweTracePackingKeyswitchKeyCreationMetadata { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub input_lwe_size: LweSize, + pub output_glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> CreateFrom + for LweTracePackingKeyswitchKey +{ + type Metadata = LweTracePackingKeyswitchKeyCreationMetadata; + + #[inline] + fn create_from(from: C, meta: Self::Metadata) -> Self { + let LweTracePackingKeyswitchKeyCreationMetadata { + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + } = meta; + Self::from_container( + from, + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +impl> ContiguousEntityContainer + for LweTracePackingKeyswitchKey +{ + type Element = C::Element; + + type EntityViewMetadata = GlweCiphertextListCreationMetadata; + + type EntityView<'this> + = GlweCiphertextListView<'this, Self::Element> + where + Self: 'this; + + type SelfViewMetadata = LweTracePackingKeyswitchKeyCreationMetadata; + + type SelfView<'this> + = LweTracePackingKeyswitchKeyView<'this, Self::Element> + where + Self: 'this; + + fn get_entity_view_creation_metadata(&self) -> Self::EntityViewMetadata { + GlweCiphertextListCreationMetadata { + glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } + + fn get_entity_view_pod_size(&self) -> usize { + self.input_key_element_encrypted_size() * self.output_glwe_size.to_glwe_dimension().0 + } + + /// Unimplemented for [`LweTracePackingKeyswitchKey`]. At the moment it does not + /// make sense to return "sub" packing keyswitch keys. + fn get_self_view_creation_metadata(&self) -> Self::SelfViewMetadata { + LweTracePackingKeyswitchKeyCreationMetadata { + decomp_base_log: self.decomposition_base_log(), + decomp_level_count: self.decomposition_level_count(), + input_lwe_size: self.input_lwe_size(), + output_glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } +} + +impl> ContiguousEntityContainerMut + for LweTracePackingKeyswitchKey +{ + type EntityMutView<'this> + = GlweCiphertextListMutView<'this, Self::Element> + where + Self: 'this; + + type SelfMutView<'this> + = LweTracePackingKeyswitchKeyMutView<'this, Self::Element> + where + Self: 'this; +} + +pub struct LweTracePackingKeyswitchKeyConformanceParams { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub input_lwe_size: LweSize, + pub output_glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> ParameterSetConformant for LweTracePackingKeyswitchKey { + type ParameterSet = LweTracePackingKeyswitchKeyConformanceParams; + + fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { + let Self { + data, + decomp_base_log, + decomp_level_count, + input_lwe_size, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + } = self; + + *ciphertext_modulus == parameter_set.ciphertext_modulus + && data.container_len() + == lwe_tpksk_size( + parameter_set.decomp_level_count, + parameter_set.output_glwe_size, + parameter_set.polynomial_size, + ) + && *decomp_base_log == parameter_set.decomp_base_log + && *decomp_level_count == parameter_set.decomp_level_count + && *input_lwe_size == parameter_set.input_lwe_size + && *output_glwe_size == parameter_set.output_glwe_size + && *polynomial_size == parameter_set.polynomial_size + } +} diff --git a/tfhe/src/core_crypto/entities/mod.rs b/tfhe/src/core_crypto/entities/mod.rs index f951c7b1c9..a92276fec2 100644 --- a/tfhe/src/core_crypto/entities/mod.rs +++ b/tfhe/src/core_crypto/entities/mod.rs @@ -11,6 +11,8 @@ pub mod ggsw_ciphertext; pub mod ggsw_ciphertext_list; pub mod glwe_ciphertext; pub mod glwe_ciphertext_list; +pub mod glwe_keyswitch_key; +pub mod glwe_relinearization_key; pub mod glwe_secret_key; pub mod gsw_ciphertext; pub mod lwe_bootstrap_key; @@ -25,6 +27,7 @@ pub mod lwe_private_functional_packing_keyswitch_key; pub mod lwe_private_functional_packing_keyswitch_key_list; pub mod lwe_public_key; pub mod lwe_secret_key; +pub mod lwe_trace_packing_keyswitch_key; pub mod ntt_ggsw_ciphertext; pub mod ntt_ggsw_ciphertext_list; pub mod ntt_lwe_bootstrap_key; @@ -68,6 +71,8 @@ pub use ggsw_ciphertext::*; pub use ggsw_ciphertext_list::*; pub use glwe_ciphertext::*; pub use glwe_ciphertext_list::*; +pub use glwe_keyswitch_key::*; +pub use glwe_relinearization_key::*; pub use glwe_secret_key::*; pub use gsw_ciphertext::*; pub use lwe_bootstrap_key::*; @@ -82,6 +87,7 @@ pub use lwe_private_functional_packing_keyswitch_key::*; pub use lwe_private_functional_packing_keyswitch_key_list::*; pub use lwe_public_key::*; pub use lwe_secret_key::*; +pub use lwe_trace_packing_keyswitch_key::*; pub use ntt_ggsw_ciphertext::*; pub use ntt_ggsw_ciphertext_list::*; pub use ntt_lwe_bootstrap_key::*;