Skip to content

Commit

Permalink
chore(gpu): update multi-bit params, add noise test for the classical…
Browse files Browse the repository at this point in the history
… & multi-bit PBS
  • Loading branch information
agnesLeroy committed Dec 19, 2024
1 parent bda8ab0 commit 622ac8f
Show file tree
Hide file tree
Showing 12 changed files with 1,077 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use super::*;
use crate::core_crypto::commons::noise_formulas::lwe_multi_bit_programmable_bootstrap::multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul;
use crate::core_crypto::commons::noise_formulas::secure_noise::minimal_lwe_variance_for_132_bits_security_gaussian;
use crate::core_crypto::commons::test_tools::{torus_modular_diff, variance};
use rayon::prelude::*;

// This is 1 / 16 which is exactly representable in an f64 (even an f32)
// 1 / 32 is too strict and fails the tests
const RELATIVE_TOLERANCE: f64 = 0.0625;

const NB_TESTS: usize = 1000;

fn lwe_encrypt_multi_bit_pbs_group_3_decrypt_custom_mod<Scalar>(params: MultiBitTestParams<Scalar>)
where
Scalar: UnsignedTorus + Sync + Send + CastFrom<usize> + CastInto<usize>,
{
let input_lwe_dimension = params.input_lwe_dimension;
let lwe_noise_distribution = params.lwe_noise_distribution;
let glwe_noise_distribution = params.glwe_noise_distribution;
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let pbs_decomposition_base_log = params.decomp_base_log;
let pbs_decomposition_level_count = params.decomp_level_count;
let grouping_factor = params.grouping_factor;
assert_eq!(grouping_factor.0, 3);

let modulus_as_f64 = if ciphertext_modulus.is_native_modulus() {
2.0f64.powi(Scalar::BITS as i32)
} else {
ciphertext_modulus.get_custom_modulus() as f64
};

let expected_variance = multi_bit_pbs_variance_132_bits_security_gaussian_gf_3_fft_mul(
input_lwe_dimension,
glwe_dimension,
polynomial_size,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
modulus_as_f64,
);

let mut rsc = TestResources::new();

let f = |x: Scalar| x;

let delta: Scalar = encoding_with_padding / msg_modulus;
let mut msg = msg_modulus;

let num_samples = NB_TESTS * <Scalar as CastInto<usize>>::cast_into(msg);
let mut noise_samples = Vec::with_capacity(num_samples);

let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
input_lwe_dimension,
&mut rsc.secret_random_generator,
);

let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut rsc.secret_random_generator,
);

let output_lwe_secret_key = output_glwe_secret_key.as_lwe_secret_key();

let fbsk = {
let bsk = allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
&input_lwe_secret_key,
&output_glwe_secret_key,
pbs_decomposition_base_log,
pbs_decomposition_level_count,
grouping_factor,
glwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

assert!(check_encrypted_content_respects_mod(
&*bsk,
ciphertext_modulus
));

let mut fbsk = FourierLweMultiBitBootstrapKey::new(
bsk.input_lwe_dimension(),
bsk.glwe_size(),
bsk.polynomial_size(),
bsk.decomposition_base_log(),
bsk.decomposition_level_count(),
bsk.grouping_factor(),
);

par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut fbsk);

fbsk
};

let accumulator = generate_programmable_bootstrap_glwe_lut(
polynomial_size,
glwe_dimension.to_glwe_size(),
msg_modulus.cast_into(),
ciphertext_modulus,
delta,
f,
);

assert!(check_encrypted_content_respects_mod(
&accumulator,
ciphertext_modulus
));

while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);

let current_run_samples: Vec<_> = (0..NB_TESTS)
.into_par_iter()
.map(|_| {
let mut rsc = TestResources::new();

let plaintext = Plaintext(msg * delta);

let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
plaintext,
lwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

assert!(check_encrypted_content_respects_mod(
&lwe_ciphertext_in,
ciphertext_modulus
));

let mut out_pbs_ct = LweCiphertext::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
ciphertext_modulus,
);

multi_bit_programmable_bootstrap_lwe_ciphertext(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator,
&fbsk,
params.thread_count,
true,
);

assert!(check_encrypted_content_respects_mod(
&out_pbs_ct,
ciphertext_modulus
));

let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);

let decoded = round_decode(decrypted.0, delta) % msg_modulus;

assert_eq!(decoded, f(msg));

torus_modular_diff(plaintext.0, decrypted.0, ciphertext_modulus)
})
.collect();

noise_samples.extend(current_run_samples);
}

let measured_variance = variance(&noise_samples);

let minimal_variance = minimal_lwe_variance_for_132_bits_security_gaussian(
fbsk.output_lwe_dimension(),
if ciphertext_modulus.is_native_modulus() {
2.0f64.powi(Scalar::BITS as i32)
} else {
ciphertext_modulus.get_custom_modulus() as f64
},
);

// Have a log even if it's a test to have a trace in no capture mode to eyeball variances
println!("measured_variance={measured_variance:?}");
println!("expected_variance={expected_variance:?}");
println!("minimal_variance={minimal_variance:?}");

if measured_variance.0 < expected_variance.0 {
// We are in the clear as long as we have at least the noise for security
assert!(
measured_variance.0 >= minimal_variance.0,
"Found insecure variance after PBS\n\
measure_variance={measured_variance:?}\n\
minimal_variance={minimal_variance:?}"
);
} else {
// Check we are not too far from the expected variance if we are bigger
let var_abs_diff = (expected_variance.0 - measured_variance.0).abs();
let tolerance_threshold = RELATIVE_TOLERANCE * expected_variance.0;

assert!(
var_abs_diff < tolerance_threshold,
"Absolute difference for variance: {var_abs_diff}, \
tolerance threshold: {tolerance_threshold}, \
got variance: {measured_variance:?}, \
expected variance: {expected_variance:?}"
);
}
}

create_parameterized_test!(lwe_encrypt_multi_bit_pbs_group_3_decrypt_custom_mod {
NOISE_TEST_PARAMS_MULTI_BIT_GROUP_3_4_BITS_NATIVE_U64_132_BITS_GAUSSIAN
});
84 changes: 84 additions & 0 deletions tfhe/src/core_crypto/algorithms/test/noise_distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::*;

mod lwe_encryption_noise;
mod lwe_keyswitch_noise;
mod lwe_multi_bit_programmable_bootstrapping_noise;
mod lwe_programmable_bootstrapping_noise;

#[allow(clippy::excessive_precision)]
Expand All @@ -28,3 +29,86 @@ pub const NOISE_TEST_PARAMS_4_BITS_NATIVE_U64_132_BITS_GAUSSIAN: ClassicTestPara
message_modulus_log: MessageModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
};

// ---- GAUSSIAN ---------------------------------------------------------

#[allow(clippy::excessive_precision)]
#[allow(dead_code)]
pub const NOISE_TEST_PARAMS_MULTI_BIT_GROUP_3_2_BITS_NATIVE_U64_132_BITS_GAUSSIAN:
MultiBitTestParams<u64> = MultiBitTestParams {
input_lwe_dimension: LweDimension(256 * 3),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
1.1098369627275701e-05,
)),
decomp_base_log: DecompositionBaseLog(17),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(3),
polynomial_size: PolynomialSize(512),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
1.9524392655548086e-11,
)),
message_modulus_log: MessageModulusLog(2),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
};

#[allow(clippy::excessive_precision)]
pub const NOISE_TEST_PARAMS_MULTI_BIT_GROUP_3_4_BITS_NATIVE_U64_132_BITS_GAUSSIAN:
MultiBitTestParams<u64> = MultiBitTestParams {
input_lwe_dimension: LweDimension(279 * 3),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
3.3747142481837397e-06,
)),
decomp_base_log: DecompositionBaseLog(22),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
2.845267479601915e-15,
)),
message_modulus_log: MessageModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
};

#[allow(clippy::excessive_precision)]
#[allow(dead_code)]
pub const NOISE_TEST_PARAMS_MULTI_BIT_GROUP_3_6_BITS_NATIVE_U64_132_BITS_GAUSSIAN:
MultiBitTestParams<u64> = MultiBitTestParams {
input_lwe_dimension: LweDimension(326 * 3),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
2.962875621642539e-07,
)),
decomp_base_log: DecompositionBaseLog(14),
decomp_level_count: DecompositionLevelCount(2),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(8192),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
2.168404344971009e-19,
)),
message_modulus_log: MessageModulusLog(6),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
};

// ---- TUNIFORM ---------------------------------------------------------

#[allow(clippy::excessive_precision)]
#[allow(dead_code)]
pub const NOISE_TEST_PARAMS_MULTI_BIT_GROUP_3_4_BITS_NATIVE_U64_132_BITS_TUNIFORM:
MultiBitTestParams<u64> = MultiBitTestParams {
input_lwe_dimension: LweDimension(295 * 3), // 295 after FFT better fix
lwe_noise_distribution: DynamicDistribution::new_t_uniform(46),
decomp_base_log: DecompositionBaseLog(22),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
glwe_noise_distribution: DynamicDistribution::new_t_uniform(17),
message_modulus_log: MessageModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
};
Loading

0 comments on commit 622ac8f

Please sign in to comment.