Skip to content

Commit

Permalink
refactor(optimizer): move external partition during change_part creation
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 9, 2024
1 parent 72d5b7b commit 901ce04
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ impl<'dag> DagBuilder<'dag> {
pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
location: Location,
) -> OperatorIndex {
assert!(
Expand All @@ -274,8 +274,8 @@ impl<'dag> DagBuilder<'dag> {
self.add_operator(
Operator::ChangePartition {
input,
src_partition: src_partition.cloned(),
dst_partition: dst_partition.cloned(),
src_partition,
dst_partition,
},
location,
)
Expand Down Expand Up @@ -604,8 +604,8 @@ impl Dag {
pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
) -> OperatorIndex {
self.builder(DEFAULT_CIRCUIT).add_change_partition(
input,
Expand Down Expand Up @@ -913,13 +913,14 @@ mod tests {
let b = builder.add_input(1, Shape::number(), Location::Unknown);
let c = builder.add_dot([a, b], [1, 1], Location::Unknown);
let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown);
let _d = builder.add_change_partition(d, Some(&tfhers_part), None, Location::Unknown);
let _d =
builder.add_change_partition(d, Some(tfhers_part.clone()), None, Location::Unknown);
let mut builder = graph.builder("main2");
let e = builder.add_input(2, Shape::number(), Location::Unknown);
let f = builder.add_input(2, Shape::number(), Location::Unknown);
let g = builder.add_dot([e, f], [2, 2], Location::Unknown);
let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown);
let _h = builder.add_change_partition(h, None, Some(&tfhers_part), Location::Unknown);
let _h = builder.add_change_partition(h, None, Some(tfhers_part), Location::Unknown);
graph.tag_operator_as_output(c);
}

Expand Down Expand Up @@ -962,7 +963,7 @@ mod tests {
variance: 0.0_f64,
};
let change_part =
builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown);
builder.add_change_partition(lut2, Some(tfhers_part.clone()), None, Location::Unknown);

let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part];
for (expected_i, op_index) in ops_index.iter().enumerate() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,9 @@ fn test_optimize_tfhers_in_out_dot_compute() {
};
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None);
let dot = dag.add_dot([change_part1], [2]);
_ = dag.add_change_partition(dot, None, Some(&tfhers_partition));
_ = dag.add_change_partition(dot, None, Some(tfhers_partition.clone()));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
Expand All @@ -877,10 +877,10 @@ fn test_optimize_tfhers_2lut_compute() {
let tfhers_precision = 11;
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out));
let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
Expand All @@ -904,10 +904,10 @@ fn test_optimize_tfhers_different_in_out_2lut_compute() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 8;
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out));
let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
Expand All @@ -926,7 +926,7 @@ fn test_optimize_tfhers_input_constraints() {
let mut dag = unparametrized::Dag::new();
let tfhers_precision = 4;
let input = dag.add_input(tfhers_precision, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, tfhers_precision);
let out = dag.add_dot([lut], [128]);
dag.add_composition(out, input);
Expand Down Expand Up @@ -960,7 +960,7 @@ fn test_optimize_tfhers_output_constraints() {
let input = dag.add_input(tfhers_precision, Shape::number());
let lut = dag.add_lut(input, FunctionTable::UNKWOWN, tfhers_precision);
let dot = dag.add_dot([lut], [128]);
let out = dag.add_change_partition(dot, None, Some(&tfhers_partition));
let out = dag.add_change_partition(dot, None, Some(tfhers_partition.clone()));
dag.add_composition(out, input);
dag
};
Expand Down Expand Up @@ -997,11 +997,11 @@ fn test_optimize_tfhers_to_concrete_and_back_example() {
Shape::vector((concrete_precision / msg_width).into()),
);
// to concrete
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, concrete_precision);
// from concrete
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision);
let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition));
let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition.clone()));

let sol = optimize(&dag, &None, PartitionIndex(0));
assert!(sol.is_some());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,9 @@ pub mod tests {
};
let mut dag = unparametrized::Dag::new();
let input1 = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None);
let dot = dag.add_dot([change_part1], [2]);
_ = dag.add_change_partition(dot, None, Some(&tfhers_partition));
_ = dag.add_change_partition(dot, None, Some(tfhers_partition));

let partitions = partitionning(&dag);
assert!(partitions.nb_partitions == 1);
Expand All @@ -438,9 +438,9 @@ pub mod tests {
};
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16);
let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition));
let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition));

let partitions = partitionning(&dag);
assert!(partitions.nb_partitions == 2);
Expand Down Expand Up @@ -483,9 +483,9 @@ pub mod tests {
};
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16);
let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition_out));
let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition_out.clone()));

let p_cut = PartitionCut::for_each_precision(&dag);
let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION);
Expand Down Expand Up @@ -526,10 +526,10 @@ pub mod tests {
};
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16);
let change_part2 = dag.add_change_partition(lut2, None, Some(&tfhers_partition));
let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition));

let partitions = partitionning(&dag);
assert!(partitions.nb_partitions == 3);
Expand Down Expand Up @@ -585,10 +585,10 @@ pub mod tests {
};
let mut dag = unparametrized::Dag::new();
let input = dag.add_input(16, Shape::number());
let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None);
let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None);
let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16);
let change_part2 = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out));
let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition_out.clone()));

let p_cut = PartitionCut::for_each_precision(&dag);
let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION);
Expand Down

0 comments on commit 901ce04

Please sign in to comment.