Skip to content

Commit

Permalink
feat(optimizer): block macro parameters in external partitions
Browse files Browse the repository at this point in the history
we block them by reducing the search space to a single set of parameters
which is provided by the external partition
  • Loading branch information
youben11 committed Aug 9, 2024
1 parent 67e18ab commit a1398fd
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,20 @@ impl Dag {

#[cfg(test)]
mod tests {
use crate::{
optimization::dag::multi_parameters::optimize::MacroParameters, parameters::GlweParameters,
};

use super::*;

const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: 0,
glwe_dimension: 0,
},
internal_dim: 0,
};

#[test]
fn output_marking() {
let mut graph = Dag::new();
Expand All @@ -892,6 +904,7 @@ mod tests {
let mut graph = Dag::new();
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number(), Location::Unknown);
Expand Down Expand Up @@ -942,6 +955,7 @@ mod tests {

let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part =
builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct PartialMicroParameters {
complexity: f64,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[derive(Hash, Debug, Copy, Clone, Eq, PartialEq)]
pub struct MacroParameters {
pub glwe_params: GlweParameters,
pub internal_dim: u64,
Expand Down Expand Up @@ -920,6 +920,11 @@ pub fn optimize(
ciphertext_modulus_log,
};

let dag_p_cut = p_cut.as_ref().map_or_else(
|| PartitionCut::for_each_precision(dag),
std::clone::Clone::clone,
);

let dag = analyze(dag, &noise_config, p_cut, default_partition)?;
let kappa =
error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability);
Expand Down Expand Up @@ -954,11 +959,27 @@ pub fn optimize(
let mut best_params: Option<Parameters> = None;
for iter in 0..=10 {
for partition in PartitionIndex::range(0, nb_partitions).rev() {
// reduce search space to the parameters of external partitions
let partition_search_space = if dag_p_cut.is_external_partition(&partition) {
let external_part =
&dag_p_cut.external_partitions[partition.0 - dag_p_cut.n_internal_partitions()];
let mut reduced_search_space = search_space.clone();
reduced_search_space.glwe_dimensions =
[external_part.macro_params.glwe_params.glwe_dimension].to_vec();
reduced_search_space.glwe_log_polynomial_sizes =
[external_part.macro_params.glwe_params.log2_polynomial_size].to_vec();
reduced_search_space.internal_lwe_dimensions =
[external_part.macro_params.internal_dim].to_vec();
reduced_search_space
} else {
search_space.clone()
};

let new_params = optimize_macro(
security_level,
ciphertext_modulus_log,
fft_precision,
search_space,
&partition_search_space,
partition,
&used_tlu_keyswitch,
&used_conversion_keyswitch,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::float_cmp)]

use once_cell::sync::Lazy;
use optimization::dag::multi_parameters::partition_cut::ExternalPartition;

use super::*;
use crate::computing_cost::cpu::CpuComplexity;
Expand Down Expand Up @@ -851,3 +852,99 @@ fn test_bug_with_zero_noise() {
let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
}

const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: 11,
glwe_dimension: 1,
},
internal_dim: 887,
};

#[test]
fn test_optimize_tfhers_in_out_dot_compute() {
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None);
let dot = dag.add_dot([change_part1], [2]);
_ = dag.add_change_partition(dot, None, Some(&tfhers_partition));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
println!("solution: {:?}", sol.unwrap());
}

#[test]
fn test_optimize_tfhers_2lut_compute() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 11;
let input = dag.add_input(tfhers_precision, Shape::number());
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
}

#[test]
fn test_optimize_tfhers_different_in_out_2lut_compute() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 8;
let input = dag.add_input(tfhers_precision, Shape::number());
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers_in"),
macro_params: DUMMY_MACRO_PARAM,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers_out"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
}

#[test]
fn test_optimize_tfhers_to_concrete_and_back_example() {
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let concrete_precision = 8;
let msg_width = 2;
let carry_width = 2;
let tfhers_precision = msg_width + carry_width;

let mut dag = unparametrized::Dag::new();
let input = dag.add_input(
tfhers_precision,
Shape::vector((concrete_precision / msg_width).into()),
);
// to concrete
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, concrete_precision);
// from concrete
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
use crate::optimization::dag::solo_key::analyze::out_variances;
use crate::optimization::dag::solo_key::symbolic_variance::SymbolicVariance;

use super::optimize::MacroParameters;

const ROUND_INNER_MULTI_PARAMETER: bool = false;
const ROUND_EXTERNAL_MULTI_PARAMETER: bool = !ROUND_INNER_MULTI_PARAMETER && true;

#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct ExternalPartition {
pub name: String,
// TODO add params (maybe just macros)
pub macro_params: MacroParameters,
}

impl std::fmt::Display for ExternalPartition {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ pub mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, Shape, Weights};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::optimize::MacroParameters;
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;
use crate::parameters::GlweParameters;

fn default_p_cut() -> PartitionCut {
PartitionCut::from_precisions(&[2, 128])
Expand Down Expand Up @@ -357,12 +359,21 @@ pub mod tests {
PartitionIndex(default)
}

const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: 0,
glwe_dimension: 0,
},
internal_dim: 0,
};

#[test]
fn test_tfhers_in_out_dot_compute() {
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None);
let dot = dag.add_dot([change_part1], [2]);
Expand All @@ -381,6 +392,7 @@ pub mod tests {
let input = dag.add_input(16, Shape::number());
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16);
Expand Down Expand Up @@ -416,9 +428,11 @@ pub mod tests {
let input = dag.add_input(16, Shape::number());
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers_in"),
macro_params: DUMMY_MACRO_PARAM,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers_out"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16);
Expand Down Expand Up @@ -458,6 +472,7 @@ pub mod tests {
let input = dag.add_input(16, Shape::number());
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
Expand Down Expand Up @@ -507,9 +522,11 @@ pub mod tests {
let input = dag.add_input(16, Shape::number());
let tfhers_partition_in = ExternalPartition {
name: String::from("tfhers_in"),
macro_params: DUMMY_MACRO_PARAM,
};
let tfhers_partition_out = ExternalPartition {
name: String::from("tfhers_out"),
macro_params: DUMMY_MACRO_PARAM,
};
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
Expand Down

0 comments on commit a1398fd

Please sign in to comment.