From 901ce0419f94fb34d7ae30f8c92e67cbf38b2e87 Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 7 Aug 2024 16:18:28 +0100 Subject: [PATCH] refactor(optimizer): move external partition during change_part creation --- .../src/dag/unparametrized.rs | 19 +++++++++--------- .../dag/multi_parameters/optimize/tests.rs | 20 +++++++++---------- .../dag/multi_parameters/partitionning.rs | 20 +++++++++---------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 960f4ce98c..a59e0bd4c7 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -263,8 +263,8 @@ impl<'dag> DagBuilder<'dag> { pub fn add_change_partition( &mut self, input: OperatorIndex, - src_partition: Option<&ExternalPartition>, - dst_partition: Option<&ExternalPartition>, + src_partition: Option, + dst_partition: Option, location: Location, ) -> OperatorIndex { assert!( @@ -274,8 +274,8 @@ impl<'dag> DagBuilder<'dag> { self.add_operator( Operator::ChangePartition { input, - src_partition: src_partition.cloned(), - dst_partition: dst_partition.cloned(), + src_partition, + dst_partition, }, location, ) @@ -604,8 +604,8 @@ impl Dag { pub fn add_change_partition( &mut self, input: OperatorIndex, - src_partition: Option<&ExternalPartition>, - dst_partition: Option<&ExternalPartition>, + src_partition: Option, + dst_partition: Option, ) -> OperatorIndex { self.builder(DEFAULT_CIRCUIT).add_change_partition( input, @@ -913,13 +913,14 @@ mod tests { let b = builder.add_input(1, Shape::number(), Location::Unknown); let c = builder.add_dot([a, b], [1, 1], Location::Unknown); let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown); - let _d = builder.add_change_partition(d, Some(&tfhers_part), None, Location::Unknown); + let _d = + builder.add_change_partition(d, Some(tfhers_part.clone()), None, Location::Unknown); let mut builder = graph.builder("main2"); let e = builder.add_input(2, Shape::number(), Location::Unknown); let f = builder.add_input(2, Shape::number(), Location::Unknown); let g = builder.add_dot([e, f], [2, 2], Location::Unknown); let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown); - let _h = builder.add_change_partition(h, None, Some(&tfhers_part), Location::Unknown); + let _h = builder.add_change_partition(h, None, Some(tfhers_part), Location::Unknown); graph.tag_operator_as_output(c); } @@ -962,7 +963,7 @@ mod tests { variance: 0.0_f64, }; let change_part = - builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown); + builder.add_change_partition(lut2, Some(tfhers_part.clone()), None, Location::Unknown); let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part]; for (expected_i, op_index) in ops_index.iter().enumerate() { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index 9fc36062aa..5e532874af 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -850,9 +850,9 @@ fn test_optimize_tfhers_in_out_dot_compute() { }; let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None); let dot = dag.add_dot([change_part1], [2]); - _ = dag.add_change_partition(dot, None, Some(&tfhers_partition)); + _ = dag.add_change_partition(dot, None, Some(tfhers_partition.clone())); let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); @@ -877,10 +877,10 @@ fn test_optimize_tfhers_2lut_compute() { let tfhers_precision = 11; let mut dag = unparametrized::Dag::new(); let input = dag.add_input(tfhers_precision, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); - let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out)); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out)); let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); @@ -904,10 +904,10 @@ fn test_optimize_tfhers_different_in_out_2lut_compute() { let mut dag = unparametrized::Dag::new(); let tfhers_precision = 8; let input = dag.add_input(tfhers_precision, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); - let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out)); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out)); let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); @@ -926,7 +926,7 @@ fn test_optimize_tfhers_input_constraints() { let mut dag = unparametrized::Dag::new(); let tfhers_precision = 4; let input = dag.add_input(tfhers_precision, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition), None); let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, tfhers_precision); let out = dag.add_dot([lut], [128]); dag.add_composition(out, input); @@ -960,7 +960,7 @@ fn test_optimize_tfhers_output_constraints() { let input = dag.add_input(tfhers_precision, Shape::number()); let lut = dag.add_lut(input, FunctionTable::UNKWOWN, tfhers_precision); let dot = dag.add_dot([lut], [128]); - let out = dag.add_change_partition(dot, None, Some(&tfhers_partition)); + let out = dag.add_change_partition(dot, None, Some(tfhers_partition.clone())); dag.add_composition(out, input); dag }; @@ -997,11 +997,11 @@ fn test_optimize_tfhers_to_concrete_and_back_example() { Shape::vector((concrete_precision / msg_width).into()), ); // to concrete - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, concrete_precision); // from concrete let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); - let _ = dag.add_change_partition(lut2, None, Some(&tfhers_partition)); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition.clone())); let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index d647163c80..0901d809b7 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -416,9 +416,9 @@ pub mod tests { }; let mut dag = unparametrized::Dag::new(); let input1 = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input1, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None); let dot = dag.add_dot([change_part1], [2]); - _ = dag.add_change_partition(dot, None, Some(&tfhers_partition)); + _ = dag.add_change_partition(dot, None, Some(tfhers_partition)); let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 1); @@ -438,9 +438,9 @@ pub mod tests { }; let mut dag = unparametrized::Dag::new(); let input = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); - let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition)); + let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition)); let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 2); @@ -483,9 +483,9 @@ pub mod tests { }; let mut dag = unparametrized::Dag::new(); let input = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None); let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); - let change_part2 = dag.add_change_partition(lut, None, Some(&tfhers_partition_out)); + let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition_out.clone())); let p_cut = PartitionCut::for_each_precision(&dag); let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION); @@ -526,10 +526,10 @@ pub mod tests { }; let mut dag = unparametrized::Dag::new(); let input = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); - let change_part2 = dag.add_change_partition(lut2, None, Some(&tfhers_partition)); + let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition)); let partitions = partitionning(&dag); assert!(partitions.nb_partitions == 3); @@ -585,10 +585,10 @@ pub mod tests { }; let mut dag = unparametrized::Dag::new(); let input = dag.add_input(16, Shape::number()); - let change_part1 = dag.add_change_partition(input, Some(&tfhers_partition_in), None); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None); let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); - let change_part2 = dag.add_change_partition(lut2, None, Some(&tfhers_partition_out)); + let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition_out.clone())); let p_cut = PartitionCut::for_each_precision(&dag); let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION);