Skip to content

Commit

Permalink
feat(optimizer): partition with external partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 9, 2024
1 parent c3c0976 commit 67e18ab
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::iter::{empty, once};
use std::ops::Deref;

use crate::dag::operator::tensor::{ClearTensor, Shape};
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;

use super::DotKind;

Expand Down Expand Up @@ -106,6 +107,8 @@ pub enum Operator {
},
ChangePartition {
input: OperatorIndex,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
},
}

Expand All @@ -118,7 +121,7 @@ impl Operator {
Self::UnsafeCast { input, .. }
| Self::Lut { input, .. }
| Self::Round { input, .. }
| Self::ChangePartition { input } => Box::new(once(input)),
| Self::ChangePartition { input, .. } => Box::new(once(input)),
}
}
}
Expand Down Expand Up @@ -194,8 +197,22 @@ impl fmt::Display for Operator {
} => {
write!(f, "ROUND[%{}] : u{out_precision}", input.0)?;
}
Self::ChangePartition { input } => {
write!(f, "ChangePartition[%{}]", input.0)?;
Self::ChangePartition {
input,
src_partition,
dst_partition,
} => {
write!(f, "CHANGE_PARTITION[%{}] : {{", input.0)?;
if let Some(partition) = src_partition {
write!(f, "src_partition: {}", partition.name)?;
}
if let Some(partition) = dst_partition {
if src_partition.is_some() {
write!(f, ", ")?;
}
write!(f, "dst_partition: {}", partition.name)?;
}
write!(f, "}}")?;
}
}
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. }
| Operator::ChangePartition { input } => input.0 = old_index_to_new[input.0],
| Operator::ChangePartition { input, .. } => input.0 = old_index_to_new[input.0],
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
for input in inputs {
input.0 = old_index_to_new[input.0];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
};
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;
use std::{
collections::{HashMap, HashSet},
fmt,
Expand Down Expand Up @@ -262,9 +263,22 @@ impl<'dag> DagBuilder<'dag> {
pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
location: Location,
) -> OperatorIndex {
self.add_operator(Operator::ChangePartition { input }, location)
assert!(
src_partition.is_some() || dst_partition.is_some(),
"change_partition: src or dest partition need to be set"
);
self.add_operator(
Operator::ChangePartition {
input,
src_partition: src_partition.cloned(),
dst_partition: dst_partition.cloned(),
},
location,
)
}

pub fn add_round_op(
Expand Down Expand Up @@ -425,7 +439,7 @@ impl<'dag> DagBuilder<'dag> {
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. }
| Operator::ChangePartition { input } => self.dag.out_shapes[input.0].clone(),
| Operator::ChangePartition { input, .. } => self.dag.out_shapes[input.0].clone(),
Operator::Dot {
kind: DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor,
..
Expand Down Expand Up @@ -460,7 +474,7 @@ impl<'dag> DagBuilder<'dag> {
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
self.dag.out_precisions[inputs[0].0]
}
Operator::ChangePartition { input } => self.dag.out_precisions[input.0],
Operator::ChangePartition { input, .. } => self.dag.out_precisions[input.0],
}
}
}
Expand Down Expand Up @@ -587,6 +601,20 @@ impl Dag {
.add_lut(input, table, out_precision, Location::Unknown)
}

pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
) -> OperatorIndex {
self.builder(DEFAULT_CIRCUIT).add_change_partition(
input,
src_partition,
dst_partition,
Location::Unknown,
)
}

pub fn add_dot(
&mut self,
inputs: impl Into<Vec<OperatorIndex>>,
Expand Down Expand Up @@ -862,18 +890,21 @@ mod tests {
#[allow(clippy::many_single_char_names)]
fn graph_builder() {
let mut graph = Dag::new();
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
};
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number(), Location::Unknown);
let b = builder.add_input(1, Shape::number(), Location::Unknown);
let c = builder.add_dot([a, b], [1, 1], Location::Unknown);
let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown);
let _d = builder.add_change_partition(d, Location::Unknown);
let _d = builder.add_change_partition(d, Some(&tfhers_part), None, Location::Unknown);
let mut builder = graph.builder("main2");
let e = builder.add_input(2, Shape::number(), Location::Unknown);
let f = builder.add_input(2, Shape::number(), Location::Unknown);
let g = builder.add_dot([e, f], [2, 2], Location::Unknown);
let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown);
let _h = builder.add_change_partition(h, Location::Unknown);
let _h = builder.add_change_partition(h, None, Some(&tfhers_part), Location::Unknown);
graph.tag_operator_as_output(c);
}

Expand Down Expand Up @@ -909,7 +940,11 @@ mod tests {

let lut2 = builder.add_lut(dot, FunctionTable::UNKWOWN, 2, Location::Unknown);

let change_part = builder.add_change_partition(lut2, Location::Unknown);
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
};
let change_part =
builder.add_change_partition(lut2, Some(&tfhers_part), None, Location::Unknown);

let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part];
for (expected_i, op_index) in ops_index.iter().enumerate() {
Expand Down Expand Up @@ -959,7 +994,11 @@ mod tests {
table: FunctionTable::UNKWOWN,
out_precision: 2,
},
Operator::ChangePartition { input: lut2 }
Operator::ChangePartition {
input: lut2,
src_partition: Some(tfhers_part.clone()),
dst_partition: None
}
]
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,14 @@ impl VariancedDag {
acc + var[operator.partition().instruction_partition].clone()
* square(*weight as f64)
}),
Operator::UnsafeCast { .. } => {
Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => {
operator.get_inputs_iter().next().unwrap().variance()
[operator.partition().instruction_partition]
.clone()
}
Operator::Round { .. } => {
unreachable!("Round should have been either expanded or integrated to a lut")
}
Operator::ChangePartition { .. } => {
// TODO
todo!("TODO")
}
};
// We add the noise for the transitions to alternative representations
operator
Expand Down Expand Up @@ -1326,6 +1322,6 @@ pub mod tests {
let p_cut = PartitionCut::from_precisions(&precisions);
let dag =
super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION).unwrap();
assert!(dag.nb_partitions == p_cut.p_cut.len() + 1);
assert!(dag.nb_partitions == p_cut.n_partitions());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ fn test_partition_chain(decreasing: bool) {
let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap();
let nb_partitions = sol.macro_params.len();
assert!(
nb_partitions == (p_cut.p_cut.len() + 1),
nb_partitions == p_cut.n_partitions(),
"bad nb partitions {} {p_cut}",
sol.macro_params.len()
);
Expand Down
Loading

0 comments on commit 67e18ab

Please sign in to comment.