From 72d5b7bbb42a34b9a85a2864d71091016286a407 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 7 Aug 2024 08:27:47 +0100 Subject: [PATCH] feat(optimizer): constrain optimizer with external max_variance --- .../src/dag/unparametrized.rs | 4 + .../dag/multi_parameters/analyze.rs | 110 +++++++++++++- .../dag/multi_parameters/operations_value.rs | 9 ++ .../dag/multi_parameters/optimize/tests.rs | 136 +++++++++++++----- .../dag/multi_parameters/partition_cut.rs | 27 +++- .../dag/multi_parameters/partitionning.rs | 103 ++++++++++--- .../dag/multi_parameters/partitions.rs | 3 + .../dag/multi_parameters/symbolic_variance.rs | 14 ++ .../src/optimization/decomposition/cmux.rs | 30 ++++ 9 files changed, 371 insertions(+), 65 deletions(-) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 3f021756f8..960f4ce98c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -905,6 +905,8 @@ mod tests { let tfhers_part = ExternalPartition { name: String::from("tfhers"), macro_params: DUMMY_MACRO_PARAM, + max_variance: 0.0_f64, + variance: 0.0_f64, }; let mut builder = graph.builder("main1"); let a = builder.add_input(1, Shape::number(), Location::Unknown); @@ -956,6 +958,8 @@ mod tests { let tfhers_part = ExternalPartition { name: String::from("tfhers"), macro_params: DUMMY_MACRO_PARAM, + max_variance: 0.0_f64, + variance: 0.0_f64, }; let change_part = builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index 9cddc67cf3..6e7662edf2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -121,6 +121,7 @@ pub struct VariancedDag { pub(crate) dag: Dag, pub(crate) partitions: Partitions, pub(crate) variances: Variances, + pub(crate) external_variance_constraints: Vec, } impl VariancedDag { @@ -133,9 +134,11 @@ impl VariancedDag { dag, partitions, variances, + external_variance_constraints: vec![], }; // We forward the noise once to verify the composability. + varianced.apply_external_partition_input_variance(); let _ = varianced.forward_noise(); varianced.check_composability()?; varianced.apply_composition_rules(); @@ -145,6 +148,8 @@ impl VariancedDag { // The noise gets computed from inputs down to outputs. if varianced.forward_noise() { // Noise settled, we return the varianced dag. + varianced.collect_external_input_constraint(); + varianced.collect_external_output_constraint(); return Ok(varianced); } // The noise of the inputs gets updated following the composition rules @@ -181,6 +186,106 @@ impl VariancedDag { } } + fn apply_external_partition_input_variance(&mut self) { + let p_cut = self.partitions.p_cut.clone(); + for (i, op) in self.dag.operators.clone().iter().enumerate() { + if let Operator::Input { .. } = op { + let partition_index = self.partitions.instrs_partition[i].instruction_partition; + if p_cut.is_external_partition(&partition_index) { + let partitions = self.partitions.clone(); + let external_partition = + &p_cut.external_partitions[p_cut.external_partition_index(partition_index)]; + let max_variance = external_partition.max_variance; + let variance = external_partition.variance; + + let mut input = self.get_operator_mut(OperatorIndex(i)); + let mut variances = input.variance().clone(); + variances.vars[partition_index.0] = SymbolicVariance::from_external_partition( + partitions.nb_partitions, + partition_index, + max_variance / variance, + ); + *(input.variance_mut()) = variances; + } + } + } + } + + fn collect_external_input_constraint(&mut self) { + let p_cut = &self.partitions.p_cut; + for (i, op) in self.dag.operators.clone().iter().enumerate() { + if let Operator::Input { + out_precision, + out_shape, + } = op + { + let partition_index = self.partitions.instrs_partition[i].instruction_partition; + if !p_cut.is_external_partition(&partition_index) { + continue; + } + + let max_variance = p_cut.external_partitions + [p_cut.external_partition_index(partition_index)] + .max_variance; + + let variances = &self.get_operator(OperatorIndex(i)).variance().vars.clone(); + for (i, variance) in variances.iter().enumerate() { + if variance.coeffs.is_nan() { + assert!(i != partition_index.0); + continue; + } + let constraint = VarianceConstraint { + precision: *out_precision, + partition: partition_index, + nb_constraints: out_shape.flat_size(), + safe_variance_bound: max_variance, + variance: variance.clone(), + }; + self.external_variance_constraints.push(constraint); + } + } + } + } + + fn collect_external_output_constraint(&mut self) { + let p_cut = self.partitions.p_cut.clone(); + for dag_op in self.dag.get_output_operators_iter() { + let DagOperator { + id: op_index, + shape: out_shape, + precision: out_precision, + .. + } = dag_op; + let optional_partition_index = p_cut.partition(&self.dag, op_index); + if optional_partition_index.is_none() { + continue; + } + let partition_index = optional_partition_index.unwrap(); + if !p_cut.is_external_partition(&partition_index) { + continue; + } + let max_variance = p_cut.external_partitions + [p_cut.external_partition_index(partition_index)] + .max_variance; + + let variances = &self.get_operator(op_index).variance().vars.clone(); + for (i, variance) in variances.iter().enumerate() { + if variance.coeffs.is_nan() { + assert!(i != partition_index.0); + continue; + } + let constraint = VarianceConstraint { + precision: *out_precision, + partition: partition_index, + nb_constraints: out_shape.flat_size(), + safe_variance_bound: max_variance, + variance: variance.clone(), + }; + self.external_variance_constraints.push(constraint); + } + } + } + /// Propagates the noise downward in the graph. fn forward_noise(&mut self) -> bool { // We save the old variance to compute the diff at the end. @@ -343,7 +448,9 @@ pub fn analyze( let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition); let partitioned_dag = PartitionedDag { dag, partitions }; let varianced_dag = VariancedDag::try_from_partitioned(partitioned_dag)?; - let variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config); + let mut variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config); + // add external variance constraints + variance_constraints.extend_from_slice(varianced_dag.external_variance_constraints.as_slice()); let undominated_variance_constraints = VarianceConstraint::remove_dominated(&variance_constraints); let operations_count_per_instrs = collect_operations_count(&varianced_dag); @@ -560,6 +667,7 @@ fn collect_all_variance_constraints( dag, partitions, variances, + .. } = dag; let mut constraints = vec![]; for op in dag.get_operators_iter() { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs index e70b214266..7b0b90b268 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs @@ -193,6 +193,15 @@ impl OperationsValue { } } + pub fn is_nan(&self) -> bool { + for val in self.values.iter() { + if !val.is_nan() { + return false; + } + } + true + } + pub fn input(&mut self, partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.input(partition)] } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 7f50d22c19..9fc36062aa 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -5,27 +5,13 @@ use optimization::dag::multi_parameters::partition_cut::ExternalPartition; use super::*; use crate::computing_cost::cpu::CpuComplexity; -use crate::config; use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; use crate::dag::unparametrized; +use crate::optimization::dag::multi_parameters::partitionning::tests::{ + get_tfhers_noise_br, SHARED_CACHES, TFHERS_MACRO_PARAMS, +}; use crate::optimization::dag::solo_key; use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag}; -use crate::optimization::decomposition; - -const CIPHERTEXT_MODULUS_LOG: u32 = 64; -const FFT_PRECISION: u32 = 53; - -static SHARED_CACHES: Lazy = Lazy::new(|| { - let processing_unit = config::ProcessingUnit::Cpu; - decomposition::cache( - 128, - processing_unit, - None, - true, - CIPHERTEXT_MODULUS_LOG, - FFT_PRECISION, - ) -}); const _4_SIGMA: f64 = 0.000_063_342_483_999_973; @@ -853,22 +839,17 @@ fn test_bug_with_zero_noise() { assert!(sol.is_some()); } -const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters { - glwe_params: GlweParameters { - log2_polynomial_size: 11, - glwe_dimension: 1, - }, - internal_dim: 887, -}; - #[test] fn test_optimize_tfhers_in_out_dot_compute() { - let mut dag = unparametrized::Dag::new(); - let input1 = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input1 = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None); let dot = dag.add_dot([change_part1], [2]); _ = dag.add_change_partition(dot, None, Some(&tfhers_partition)); @@ -880,17 +861,22 @@ fn test_optimize_tfhers_in_out_dot_compute() { #[test] fn test_optimize_tfhers_2lut_compute() { - let mut dag = unparametrized::Dag::new(); - let tfhers_precision = 11; - let input = dag.add_input(tfhers_precision, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition_in = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; let tfhers_partition_out = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; + let tfhers_precision = 11; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(tfhers_precision, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); @@ -902,17 +888,22 @@ fn test_optimize_tfhers_2lut_compute() { #[test] fn test_optimize_tfhers_different_in_out_2lut_compute() { - let mut dag = unparametrized::Dag::new(); - let tfhers_precision = 8; - let input = dag.add_input(tfhers_precision, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition_in = ExternalPartition { name: String::from("tfhers_in"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, }; let tfhers_partition_out = ExternalPartition { name: String::from("tfhers_out"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 8; + let input = dag.add_input(tfhers_precision, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); @@ -922,11 +913,78 @@ fn test_optimize_tfhers_different_in_out_2lut_compute() { assert!(sol.is_some()); } +#[test] +fn test_optimize_tfhers_input_constraints() { + let variances = [1.0, 6.14e-14, 2.14e-16]; + let dag_builder = |variance: f64| -> Dag { + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 4; + let input = dag.add_input(tfhers_precision, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); + let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, tfhers_precision); + let out = dag.add_dot([lut], [128]); + dag.add_composition(out, input); + dag + }; + + let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let mut last_complexity = sol.unwrap().complexity; + for variance in &variances[1..] { + let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let complexity = sol.unwrap().complexity; + assert!(complexity > last_complexity); + last_complexity = complexity; + } +} + +#[test] +fn test_optimize_tfhers_output_constraints() { + let variances = [1.0, 6.14e-14, 2.14e-16]; + let dag_builder = |variance: f64| -> Dag { + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 4; + let input = dag.add_input(tfhers_precision, Shape::number()); + let lut = dag.add_lut(input, FunctionTable::UNKWOWN, tfhers_precision); + let dot = dag.add_dot([lut], [128]); + let out = dag.add_change_partition(dot, None, Some(&tfhers_partition)); + dag.add_composition(out, input); + dag + }; + + let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let mut last_complexity = sol.unwrap().complexity; + for variance in &variances[1..] { + let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let complexity = sol.unwrap().complexity; + assert!(complexity > last_complexity); + last_complexity = complexity; + } +} + #[test] fn test_optimize_tfhers_to_concrete_and_back_example() { + let variance = get_tfhers_noise_br(); let tfhers_partition = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 8.0, + variance, }; let concrete_precision = 8; let msg_width = 2; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index 8eb11868da..dbd5edc22f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -14,10 +14,29 @@ use super::optimize::MacroParameters; const ROUND_INNER_MULTI_PARAMETER: bool = false; const ROUND_EXTERNAL_MULTI_PARAMETER: bool = !ROUND_INNER_MULTI_PARAMETER && true; -#[derive(Hash, Eq, PartialEq, Clone, Debug)] +#[derive(Clone, Debug)] pub struct ExternalPartition { pub name: String, pub macro_params: MacroParameters, + pub max_variance: f64, + pub variance: f64, +} + +impl Eq for ExternalPartition {} + +impl PartialEq for ExternalPartition { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.macro_params == other.macro_params + && self.max_variance == other.max_variance + } +} + +impl std::hash::Hash for ExternalPartition { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.macro_params.hash(state); + } } impl std::fmt::Display for ExternalPartition { @@ -70,8 +89,12 @@ impl PartitionCut { self.external_partitions.len() } + pub fn external_partition_index(&self, partition: PartitionIndex) -> usize { + partition.0 - self.n_internal_partitions() + } + pub fn is_external_partition(&self, partition: &PartitionIndex) -> bool { - partition.0 >= self.n_internal_partitions() + partition.0 >= self.n_internal_partitions() && partition.0 < self.n_partitions() } pub fn is_internal_partition(&self, partition: &PartitionIndex) -> bool { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index 1daf7c25e6..d647163c80 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -132,6 +132,7 @@ fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { Partitions { nb_partitions: 1, instrs_partition, + p_cut: PartitionCut::empty(), } } @@ -245,6 +246,7 @@ fn resolve_by_levelled_block( Partitions { nb_partitions, instrs_partition: instrs_p, + p_cut: p_cut.clone(), } // Now we can generate transitions // Input has no transtions @@ -271,13 +273,33 @@ pub mod tests { pub const LOW_PRECISION_PARTITION: PartitionIndex = PartitionIndex(0); pub const HIGH_PRECISION_PARTITION: PartitionIndex = PartitionIndex(1); + use once_cell::sync::Lazy; + use super::*; + use crate::config; use crate::dag::operator::{FunctionTable, Shape, Weights}; use crate::dag::unparametrized; use crate::optimization::dag::multi_parameters::optimize::MacroParameters; use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition; + use crate::optimization::decomposition::cmux::get_noise_br; + use crate::optimization::decomposition::{self, PersistDecompCaches}; use crate::parameters::GlweParameters; + const CIPHERTEXT_MODULUS_LOG: u32 = 64; + const FFT_PRECISION: u32 = 53; + + pub static SHARED_CACHES: Lazy = Lazy::new(|| { + let processing_unit = config::ProcessingUnit::Cpu; + decomposition::cache( + 128, + processing_unit, + None, + true, + CIPHERTEXT_MODULUS_LOG, + FFT_PRECISION, + ) + }); + fn default_p_cut() -> PartitionCut { PartitionCut::from_precisions(&[2, 128]) } @@ -359,22 +381,41 @@ pub mod tests { PartitionIndex(default) } - const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters { - glwe_params: GlweParameters { - log2_polynomial_size: 0, - glwe_dimension: 0, - }, - internal_dim: 0, + pub const GLWE_PARAMS: GlweParameters = GlweParameters { + log2_polynomial_size: 11, + glwe_dimension: 1, + }; + + pub const TFHERS_PBS_LEVEL: u64 = 1; + + pub const TFHERS_MACRO_PARAMS: MacroParameters = MacroParameters { + glwe_params: GLWE_PARAMS, + internal_dim: 841, }; + pub fn get_tfhers_noise_br() -> f64 { + get_noise_br( + SHARED_CACHES.caches(), + GLWE_PARAMS.log2_polynomial_size, + GLWE_PARAMS.glwe_dimension, + TFHERS_MACRO_PARAMS.internal_dim, + TFHERS_PBS_LEVEL, + None, + ) + .unwrap() + } + #[test] fn test_tfhers_in_out_dot_compute() { - let mut dag = unparametrized::Dag::new(); - let input1 = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input1 = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None); let dot = dag.add_dot([change_part1], [2]); _ = dag.add_change_partition(dot, None, Some(&tfhers_partition)); @@ -388,12 +429,15 @@ pub mod tests { #[test] fn test_tfhers_in_out_lut_compute() { - let mut dag = unparametrized::Dag::new(); - let input = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition)); @@ -424,16 +468,21 @@ pub mod tests { #[test] fn test_tfhers_different_in_out_lut_compute() { - let mut dag = unparametrized::Dag::new(); - let input = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition_in = ExternalPartition { name: String::from("tfhers_in"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, }; let tfhers_partition_out = ExternalPartition { name: String::from("tfhers_out"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition_out)); @@ -468,12 +517,15 @@ pub mod tests { #[test] fn test_tfhers_in_out_2lut_compute() { - let mut dag = unparametrized::Dag::new(); - let input = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition = ExternalPartition { name: String::from("tfhers"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); @@ -518,16 +570,21 @@ pub mod tests { #[test] fn test_tfhers_different_in_out_2lut_compute() { - let mut dag = unparametrized::Dag::new(); - let input = dag.add_input(16, Shape::number()); + let variance = get_tfhers_noise_br(); let tfhers_partition_in = ExternalPartition { name: String::from("tfhers_in"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; let tfhers_partition_out = ExternalPartition { name: String::from("tfhers_out"), - macro_params: DUMMY_MACRO_PARAM, + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs index 28ee90c61b..18dd8050da 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs @@ -6,6 +6,8 @@ use std::{ use crate::dag::operator::OperatorIndex; +use super::partition_cut::PartitionCut; + #[derive(Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Copy)] pub struct PartitionIndex(pub(crate) usize); @@ -78,6 +80,7 @@ impl InstructionPartition { pub struct Partitions { pub nb_partitions: usize, pub instrs_partition: Vec, + pub p_cut: PartitionCut, } impl Index for Partitions { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index 664088bda2..7b120a48a6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -54,6 +54,20 @@ impl SymbolicVariance { r } + pub fn from_external_partition( + nb_partitions: usize, + partition: PartitionIndex, + max_variance: f64, + ) -> Self { + let mut r = Self { + partition, + coeffs: OperationsValue::zero(nb_partitions), + }; + // rust ..., offset cannot be inlined + *r.coeffs.pbs(partition) = max_variance; + r + } + pub fn coeff_input(&self, partition: PartitionIndex) -> f64 { self.coeffs[self.coeffs.index.input(partition)] } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs index fc28b5f89e..41a44e6edf 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use super::common::VERSION; +use super::DecompCaches; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct CmuxComplexityNoise { @@ -147,3 +148,32 @@ pub fn cache( }; PersistentCacheHashMap::new_no_read(&path, VERSION, function) } + +#[derive(Debug)] +pub enum MaxVarianceError { + PbsBaseLogNotFound, + PbsLevelNotFound, +} + +pub fn get_noise_br( + mut cache: DecompCaches, + log2_polynomial_size: u64, + glwe_dimension: u64, + lwe_dim: u64, + pbs_level: u64, + pbs_log2_base: Option, +) -> Result { + let cmux_quantities = cache.cmux.pareto_quantities(GlweParameters { + log2_polynomial_size, + glwe_dimension, + }); + for cmux_quantity in cmux_quantities { + if cmux_quantity.decomp.level == pbs_level { + if pbs_log2_base.is_some() && cmux_quantity.decomp.log2_base == pbs_log2_base.unwrap() { + return Err(MaxVarianceError::PbsBaseLogNotFound); + } + return Ok(cmux_quantity.noise_br(lwe_dim)); + } + } + Err(MaxVarianceError::PbsLevelNotFound) +}