Skip to content

Commit

Permalink
feat(optimizer): constrain optimizer with external max_variance
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 9, 2024
1 parent a1398fd commit 72d5b7b
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,8 @@ mod tests {
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
max_variance: 0.0_f64,
variance: 0.0_f64,
};
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number(), Location::Unknown);
Expand Down Expand Up @@ -956,6 +958,8 @@ mod tests {
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
max_variance: 0.0_f64,
variance: 0.0_f64,
};
let change_part =
builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub struct VariancedDag {
pub(crate) dag: Dag,
pub(crate) partitions: Partitions,
pub(crate) variances: Variances,
pub(crate) external_variance_constraints: Vec<VarianceConstraint>,
}

impl VariancedDag {
Expand All @@ -133,9 +134,11 @@ impl VariancedDag {
dag,
partitions,
variances,
external_variance_constraints: vec![],
};

// We forward the noise once to verify the composability.
varianced.apply_external_partition_input_variance();
let _ = varianced.forward_noise();
varianced.check_composability()?;
varianced.apply_composition_rules();
Expand All @@ -145,6 +148,8 @@ impl VariancedDag {
// The noise gets computed from inputs down to outputs.
if varianced.forward_noise() {
// Noise settled, we return the varianced dag.
varianced.collect_external_input_constraint();
varianced.collect_external_output_constraint();
return Ok(varianced);
}
// The noise of the inputs gets updated following the composition rules
Expand Down Expand Up @@ -181,6 +186,106 @@ impl VariancedDag {
}
}

fn apply_external_partition_input_variance(&mut self) {
let p_cut = self.partitions.p_cut.clone();
for (i, op) in self.dag.operators.clone().iter().enumerate() {
if let Operator::Input { .. } = op {
let partition_index = self.partitions.instrs_partition[i].instruction_partition;
if p_cut.is_external_partition(&partition_index) {
let partitions = self.partitions.clone();
let external_partition =
&p_cut.external_partitions[p_cut.external_partition_index(partition_index)];
let max_variance = external_partition.max_variance;
let variance = external_partition.variance;

let mut input = self.get_operator_mut(OperatorIndex(i));
let mut variances = input.variance().clone();
variances.vars[partition_index.0] = SymbolicVariance::from_external_partition(
partitions.nb_partitions,
partition_index,
max_variance / variance,
);
*(input.variance_mut()) = variances;
}
}
}
}

fn collect_external_input_constraint(&mut self) {
let p_cut = &self.partitions.p_cut;
for (i, op) in self.dag.operators.clone().iter().enumerate() {
if let Operator::Input {
out_precision,
out_shape,
} = op
{
let partition_index = self.partitions.instrs_partition[i].instruction_partition;
if !p_cut.is_external_partition(&partition_index) {
continue;
}

let max_variance = p_cut.external_partitions
[p_cut.external_partition_index(partition_index)]
.max_variance;

let variances = &self.get_operator(OperatorIndex(i)).variance().vars.clone();
for (i, variance) in variances.iter().enumerate() {
if variance.coeffs.is_nan() {
assert!(i != partition_index.0);
continue;
}
let constraint = VarianceConstraint {
precision: *out_precision,
partition: partition_index,
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
variance: variance.clone(),
};
self.external_variance_constraints.push(constraint);
}
}
}
}

fn collect_external_output_constraint(&mut self) {
let p_cut = self.partitions.p_cut.clone();
for dag_op in self.dag.get_output_operators_iter() {
let DagOperator {
id: op_index,
shape: out_shape,
precision: out_precision,
..
} = dag_op;
let optional_partition_index = p_cut.partition(&self.dag, op_index);
if optional_partition_index.is_none() {
continue;
}
let partition_index = optional_partition_index.unwrap();
if !p_cut.is_external_partition(&partition_index) {
continue;
}
let max_variance = p_cut.external_partitions
[p_cut.external_partition_index(partition_index)]
.max_variance;

let variances = &self.get_operator(op_index).variance().vars.clone();
for (i, variance) in variances.iter().enumerate() {
if variance.coeffs.is_nan() {
assert!(i != partition_index.0);
continue;
}
let constraint = VarianceConstraint {
precision: *out_precision,
partition: partition_index,
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
variance: variance.clone(),
};
self.external_variance_constraints.push(constraint);
}
}
}

/// Propagates the noise downward in the graph.
fn forward_noise(&mut self) -> bool {
// We save the old variance to compute the diff at the end.
Expand Down Expand Up @@ -343,7 +448,9 @@ pub fn analyze(
let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
let partitioned_dag = PartitionedDag { dag, partitions };
let varianced_dag = VariancedDag::try_from_partitioned(partitioned_dag)?;
let variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config);
let mut variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config);
// add external variance constraints
variance_constraints.extend_from_slice(varianced_dag.external_variance_constraints.as_slice());
let undominated_variance_constraints =
VarianceConstraint::remove_dominated(&variance_constraints);
let operations_count_per_instrs = collect_operations_count(&varianced_dag);
Expand Down Expand Up @@ -560,6 +667,7 @@ fn collect_all_variance_constraints(
dag,
partitions,
variances,
..
} = dag;
let mut constraints = vec![];
for op in dag.get_operators_iter() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ impl OperationsValue {
}
}

pub fn is_nan(&self) -> bool {
for val in self.values.iter() {
if !val.is_nan() {
return false;
}
}
true
}

pub fn input(&mut self, partition: PartitionIndex) -> &mut f64 {
&mut self.values[self.index.input(partition)]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,13 @@ use optimization::dag::multi_parameters::partition_cut::ExternalPartition;

use super::*;
use crate::computing_cost::cpu::CpuComplexity;
use crate::config;
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::partitionning::tests::{
get_tfhers_noise_br, SHARED_CACHES, TFHERS_MACRO_PARAMS,
};
use crate::optimization::dag::solo_key;
use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag};
use crate::optimization::decomposition;

const CIPHERTEXT_MODULUS_LOG: u32 = 64;
const FFT_PRECISION: u32 = 53;

static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
let processing_unit = config::ProcessingUnit::Cpu;
decomposition::cache(
128,
processing_unit,
None,
true,
CIPHERTEXT_MODULUS_LOG,
FFT_PRECISION,
)
});

const _4_SIGMA: f64 = 0.000_063_342_483_999_973;

Expand Down Expand Up @@ -853,22 +839,17 @@ fn test_bug_with_zero_noise() {
assert!(sol.is_some());
}

const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: 11,
glwe_dimension: 1,
},
internal_dim: 887,
};

#[test]
fn test_optimize_tfhers_in_out_dot_compute() {
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let variance = get_tfhers_noise_br();
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 2.0,
variance,
};
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None);
let dot = dag.add_dot([change_part1], [2]);
_ = dag.add_change_partition(dot, None, Some(&tfhers_partition));
Expand All @@ -880,17 +861,22 @@ fn test_optimize_tfhers_in_out_dot_compute() {

#[test]
fn test_optimize_tfhers_2lut_compute() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 11;
let input = dag.add_input(tfhers_precision, Shape::number());
let variance = get_tfhers_noise_br();
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 4.0,
variance,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 4.0,
variance,
};
let tfhers_precision = 11;
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
Expand All @@ -902,17 +888,22 @@ fn test_optimize_tfhers_2lut_compute() {

#[test]
fn test_optimize_tfhers_different_in_out_2lut_compute() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 8;
let input = dag.add_input(tfhers_precision, Shape::number());
let variance = get_tfhers_noise_br();
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers_in"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 6.0,
variance,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers_out"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 6.0,
variance,
};
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 8;
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
Expand All @@ -922,11 +913,78 @@ fn test_optimize_tfhers_different_in_out_2lut_compute() {
assert!(sol.is_some());
}

#[test]
fn test_optimize_tfhers_input_constraints() {
let variances = [1.0, 6.14e-14, 2.14e-16];
let dag_builder = |variance: f64| -> Dag {
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 4.0,
variance,
};
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 4;
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, tfhers_precision);
let out = dag.add_dot([lut], [128]);
dag.add_composition(out, input);
dag
};

let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0));
assert!(sol.is_some());
let mut last_complexity = sol.unwrap().complexity;
for variance in &variances[1..] {
let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0));
assert!(sol.is_some());
let complexity = sol.unwrap().complexity;
assert!(complexity > last_complexity);
last_complexity = complexity;
}
}

#[test]
fn test_optimize_tfhers_output_constraints() {
let variances = [1.0, 6.14e-14, 2.14e-16];
let dag_builder = |variance: f64| -> Dag {
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 4.0,
variance,
};
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 4;
let input = dag.add_input(tfhers_precision, Shape::number());
let lut = dag.add_lut(input, FunctionTable::UNKWOWN, tfhers_precision);
let dot = dag.add_dot([lut], [128]);
let out = dag.add_change_partition(dot, None, Some(&tfhers_partition));
dag.add_composition(out, input);
dag
};

let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0));
assert!(sol.is_some());
let mut last_complexity = sol.unwrap().complexity;
for variance in &variances[1..] {
let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0));
assert!(sol.is_some());
let complexity = sol.unwrap().complexity;
assert!(complexity > last_complexity);
last_complexity = complexity;
}
}

#[test]
fn test_optimize_tfhers_to_concrete_and_back_example() {
let variance = get_tfhers_noise_br();
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 8.0,
variance,
};
let concrete_precision = 8;
let msg_width = 2;
Expand Down
Loading

0 comments on commit 72d5b7b

Please sign in to comment.