diff --git a/Cargo.lock b/Cargo.lock index 9b2800d..65dbda8 100755 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,11 +57,17 @@ dependencies = [ "wasi", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "indoc" -version = "1.0.9" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "libc" @@ -105,9 +111,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "ordered-float" -version = "3.9.1" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a54938017eacd63036332b4ae5c8a49fc8c0c1d6d629893057e4f13609edd06" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" dependencies = [ "num-traits", ] @@ -135,6 +141,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -152,15 +164,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.19.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +checksum = "a7a8b1990bd018761768d5e608a13df8bd1ac5f678456e0f301bb93e5f3ea16b" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -169,9 +182,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.19.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +checksum = "650dca34d463b6cdbdb02b1d71bfd6eb6b6816afc708faebb3bac1380ff4aef7" dependencies = [ "once_cell", "target-lexicon", @@ -179,9 +192,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.19.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +checksum = "09a7da8fc04a8a2084909b59f29e1b8474decac98b951d77b80b26dc45f046ad" dependencies = [ "libc", "pyo3-build-config", @@ -189,9 +202,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.19.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +checksum = "4b8a199fce11ebb28e3569387228836ea98110e43a804a530a9fd83ade36d513" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -201,11 +214,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.19.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +checksum = "93fbbfd7eb553d10036513cb122b888dcd362a945a00b06c165f2ab480d4cc3b" dependencies = [ + "heck", "proc-macro2", + "pyo3-build-config", "quote", "syn", ] @@ -278,9 +293,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" [[package]] name = "syn" -version = "1.0.109" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -301,9 +316,9 @@ checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" [[package]] name = "unindent" -version = "0.1.11" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "wasi" diff --git a/Cargo.toml b/Cargo.toml index d96f24e..9de36cb 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ crate-type = ["cdylib"] [dependencies] bit-set = "0.5" -pyo3 = "0.19" -ordered-float = "3.9" +ordered-float = "4.2" +pyo3 = "0.21" rand = "0.8" rustc-hash = "1.1" diff --git a/pyproject.toml b/pyproject.toml index fe55625..a9ef0c1 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ authors = [ ] [build-system] -requires = ["maturin>=0.15,<0.16"] +requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [tool.maturin] diff --git a/src/lib.rs b/src/lib.rs index 9dc5416..d12a7ff 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,9 @@ use bit_set::BitSet; use ordered_float::OrderedFloat; use pyo3::prelude::*; use rand::Rng; +use rand::SeedableRng; use rustc_hash::FxHashMap; -use std::collections::{BTreeSet, BinaryHeap}; +use std::collections::{BTreeSet, BinaryHeap, HashSet}; use std::f32; use FxHashMap as Dict; @@ -23,6 +24,7 @@ type BitPath = Vec<(Subgraph, Subgraph)>; type SubContraction = (Legs, Score, BitPath); /// helper struct to build contractions from bottom up +#[derive(Clone)] struct ContractionProcessor { nodes: Dict, edges: Dict>, @@ -30,6 +32,8 @@ struct ContractionProcessor { sizes: Vec, ssa: Node, ssa_path: SSAPath, + track_flops: bool, + flops: Score, } /// given log(x) and log(y) compute log(x + y), without exponentiating both @@ -94,6 +98,21 @@ fn compute_size(legs: &Legs, sizes: &Vec) -> Score { legs.iter().map(|&(ix, _)| sizes[ix as usize]).sum() } +fn compute_flops(ilegs: &Legs, jlegs: &Legs, sizes: &Vec) -> Score { + let mut flops: Score = 0.0; + let mut seen: HashSet = HashSet::with_capacity(ilegs.len()); + for &(ix, _) in ilegs { + seen.insert(ix); + flops += sizes[ix as usize]; + } + for (ix, _) in jlegs { + if !seen.contains(ix) { + flops += sizes[*ix as usize]; + } + } + flops +} + fn is_simplifiable(legs: &Legs, appearances: &Vec) -> bool { let mut prev_ix = Node::MAX; for &(ix, ix_count) in legs { @@ -131,6 +150,7 @@ impl ContractionProcessor { inputs: Vec>, output: Vec, size_dict: Dict, + track_flops: bool, ) -> ContractionProcessor { let mut nodes: Dict = Dict::default(); let mut edges: Dict> = Dict::default(); @@ -149,7 +169,7 @@ impl ContractionProcessor { indmap.insert(ind, c); edges.insert(c, std::iter::once(i as Node).collect()); appearances.push(1); - sizes.push(f32::log(size_dict[&ind] as f32, 2.0)); + sizes.push(f32::ln(size_dict[&ind] as f32)); legs.push((c, 1)); c += 1; } @@ -170,6 +190,7 @@ impl ContractionProcessor { let ssa = nodes.len() as Node; let ssa_path: SSAPath = Vec::with_capacity(2 * ssa as usize - 1); + let flops: Score = 0.0; ContractionProcessor { nodes, @@ -178,6 +199,8 @@ impl ContractionProcessor { sizes, ssa, ssa_path, + track_flops, + flops, } } @@ -225,7 +248,9 @@ impl ContractionProcessor { for (ix, _) in &legs { self.edges .entry(*ix) - .and_modify(|nodes| {nodes.insert(i);}) + .and_modify(|nodes| { + nodes.insert(i); + }) .or_insert(std::iter::once(i as Node).collect()); } self.nodes.insert(i, legs); @@ -236,12 +261,27 @@ impl ContractionProcessor { fn contract_nodes(&mut self, i: Node, j: Node) -> Node { let ilegs = self.pop_node(i); let jlegs = self.pop_node(j); + if self.track_flops { + self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes)); + } let new_legs = compute_legs(&ilegs, &jlegs, &self.appearances); let k = self.add_node(new_legs); self.ssa_path.push(vec![i, j]); k } + /// contract two nodes (which we already know the legs for), return the new node id + fn contract_nodes_given_legs(&mut self, i: Node, j: Node, new_legs: Legs) -> Node { + let ilegs = self.pop_node(i); + let jlegs = self.pop_node(j); + if self.track_flops { + self.flops = logadd(self.flops, compute_flops(&ilegs, &jlegs, &self.sizes)); + } + let k = self.add_node(new_legs); + self.ssa_path.push(vec![i, j]); + k + } + /// find any indices that appear in all terms and just remove/ignore them fn simplify_batch(&mut self) { let mut ix_to_remove = Vec::new(); @@ -366,13 +406,27 @@ impl ContractionProcessor { } /// greedily optimize the contraction order of all terms - fn optimize_greedy(&mut self, costmod: Option, temperature: Option) { - let mut rng = rand::thread_rng(); + fn optimize_greedy( + &mut self, + costmod: Option, + temperature: Option, + seed: Option, + ) { let coeff_t = temperature.unwrap_or(0.0); let log_coeff_a = f32::ln(costmod.unwrap_or(1.0)); + let mut rng = if coeff_t != 0.0 { + Some(match seed { + Some(seed) => rand::rngs::StdRng::seed_from_u64(seed), + None => rand::rngs::StdRng::from_entropy(), + }) + } else { + // zero temp - no need for rng + None + }; + let mut local_score = |sa: Score, sb: Score, sab: Score| -> Score { - let gumbel = if coeff_t != 0.0 { + let gumbel = if let Some(rng) = &mut rng { coeff_t * -f32::ln(-f32::ln(rng.gen())) } else { 0.0 as f32 @@ -424,11 +478,7 @@ impl ContractionProcessor { } // perform contraction: - // we already have the legs, so don't call contract_nodes - self.pop_node(i); - self.pop_node(j); - let k = self.add_node(klegs.clone()); - self.ssa_path.push(vec![i, j]); + let k = self.contract_nodes_given_legs(i, j, klegs.clone()); node_sizes.insert(k, ksize); for l in self.neighbors(k) { @@ -800,7 +850,6 @@ impl ContractionProcessor { // --------------------------- PYTHON FUNCTIONS ---------------------------- // #[pyfunction] -#[pyo3()] fn ssa_to_linear(ssa_path: SSAPath, n: Option) -> SSAPath { let n = match n { Some(n) => n, @@ -828,18 +877,16 @@ fn ssa_to_linear(ssa_path: SSAPath, n: Option) -> SSAPath { } #[pyfunction] -#[pyo3()] fn find_subgraphs( inputs: Vec>, output: Vec, size_dict: Dict, ) -> Vec> { - let cp = ContractionProcessor::new(inputs, output, size_dict); + let cp = ContractionProcessor::new(inputs, output, size_dict, false); cp.subgraphs() } #[pyfunction] -#[pyo3()] fn optimize_simplify( inputs: Vec>, output: Vec, @@ -847,7 +894,7 @@ fn optimize_simplify( use_ssa: Option, ) -> SSAPath { let n = inputs.len(); - let mut cp = ContractionProcessor::new(inputs, output, size_dict); + let mut cp = ContractionProcessor::new(inputs, output, size_dict, false); cp.simplify(); if use_ssa.unwrap_or(false) { cp.ssa_path @@ -857,36 +904,94 @@ fn optimize_simplify( } #[pyfunction] -#[pyo3()] fn optimize_greedy( + py: Python, inputs: Vec>, output: Vec, size_dict: Dict, costmod: Option, temperature: Option, + seed: Option, simplify: Option, use_ssa: Option, ) -> Vec> { - let n = inputs.len(); - let mut cp = ContractionProcessor::new(inputs, output, size_dict); - if simplify.unwrap_or(true) { - // perform simplifications - cp.simplify(); - } - // greddily contract each connected subgraph - cp.optimize_greedy(costmod, temperature); - // optimize any remaining disconnected terms - cp.optimize_remaining_by_size(); - if use_ssa.unwrap_or(false) { - cp.ssa_path - } else { - ssa_to_linear(cp.ssa_path, Some(n)) - } + py.allow_threads(|| { + let n = inputs.len(); + let mut cp = ContractionProcessor::new(inputs, output, size_dict, false); + if simplify.unwrap_or(true) { + // perform simplifications + cp.simplify(); + } + // greedily contract each connected subgraph + cp.optimize_greedy(costmod, temperature, seed); + // optimize any remaining disconnected terms + cp.optimize_remaining_by_size(); + if use_ssa.unwrap_or(false) { + cp.ssa_path + } else { + ssa_to_linear(cp.ssa_path, Some(n)) + } + }) +} + +#[pyfunction] +fn optimize_random_greedy_track_flops( + py: Python, + inputs: Vec>, + output: Vec, + size_dict: Dict, + ntrials: usize, + costmod: Option, + temperature: Option, + seed: Option, + simplify: Option, + use_ssa: Option, +) -> (Vec>, Score) { + py.allow_threads(|| { + let temperature = temperature.unwrap_or(0.01); + let mut rng = match seed { + Some(seed) => rand::rngs::StdRng::seed_from_u64(seed), + None => rand::rngs::StdRng::from_entropy(), + }; + let seeds = (0..ntrials).map(|_| rng.gen()).collect::>(); + + let n: usize = inputs.len(); + // construct processor and perform simplifications once + let mut cp0 = ContractionProcessor::new(inputs, output, size_dict, true); + if simplify.unwrap_or(true) { + cp0.simplify(); + } + + let mut best_path = None; + let mut best_flops = f32::INFINITY; + + for seed in seeds { + let mut cp = cp0.clone(); + // greedily contract each connected subgraph + cp.optimize_greedy(costmod, Some(temperature), Some(seed)); + // optimize any remaining disconnected terms + cp.optimize_remaining_by_size(); + + if cp.flops < best_flops { + best_flops = cp.flops; + best_path = Some(cp.ssa_path); + } + } + + // convert to base 10 for easier comparison + best_flops *= f32::consts::LOG10_E; + + if use_ssa.unwrap_or(false) { + (best_path.unwrap(), best_flops) + } else { + (ssa_to_linear(best_path.unwrap(), Some(n)), best_flops) + } + }) } #[pyfunction] -#[pyo3()] fn optimize_optimal( + py: Python, inputs: Vec>, output: Vec, size_dict: Dict, @@ -896,30 +1001,33 @@ fn optimize_optimal( simplify: Option, use_ssa: Option, ) -> Vec> { - let n = inputs.len(); - let mut cp = ContractionProcessor::new(inputs, output, size_dict); - if simplify.unwrap_or(true) { - // perform simplifications - cp.simplify(); - } - // optimally contract each connected subgraph - cp.optimize_optimal(minimize, cost_cap, search_outer); - // optimize any remaining disconnected terms - cp.optimize_remaining_by_size(); - if use_ssa.unwrap_or(false) { - cp.ssa_path - } else { - ssa_to_linear(cp.ssa_path, Some(n)) - } + py.allow_threads(|| { + let n = inputs.len(); + let mut cp = ContractionProcessor::new(inputs, output, size_dict, false); + if simplify.unwrap_or(true) { + // perform simplifications + cp.simplify(); + } + // optimally contract each connected subgraph + cp.optimize_optimal(minimize, cost_cap, search_outer); + // optimize any remaining disconnected terms + cp.optimize_remaining_by_size(); + if use_ssa.unwrap_or(false) { + cp.ssa_path + } else { + ssa_to_linear(cp.ssa_path, Some(n)) + } + }) } /// A Python module implemented in Rust. #[pymodule] -fn cotengrust(_py: Python, m: &PyModule) -> PyResult<()> { +fn cotengrust(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?; m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?; m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?; m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?; + m.add_function(wrap_pyfunction!(optimize_random_greedy_track_flops, m)?)?; m.add_function(wrap_pyfunction!(optimize_optimal, m)?)?; Ok(()) } diff --git a/tests/test_cotengrust.py b/tests/test_cotengrust.py index db386ac..04baf9c 100644 --- a/tests/test_cotengrust.py +++ b/tests/test_cotengrust.py @@ -197,18 +197,39 @@ def test_basic_rand(seed, which): @requires_cotengra def test_optimal_lattice_eq(): inputs, output, _, size_dict = ctg.utils.lattice_equation( - [4, 5], d_max=3, seed=42 + [4, 5], d_max=2, seed=42 ) path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='flops') tree = ctg.ContractionTree.from_path( inputs, output, size_dict, path=path ) - assert tree.contraction_cost() == 3628 + assert tree.is_complete() + assert tree.contraction_cost() == 964 path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='size') assert all(len(con) <= 2 for con in path) tree = ctg.ContractionTree.from_path( inputs, output, size_dict, path=path ) - assert tree.contraction_width() == pytest.approx(6.754887502163468) + assert tree.contraction_width() == pytest.approx(5) + + +@requires_cotengra +def test_optimize_random_greedy_log_flops(): + inputs, output, _, size_dict = ctg.utils.lattice_equation( + [10, 10], d_max=3, seed=42 + ) + + path, cost1 = ctgr.optimize_random_greedy_track_flops( + inputs, output, size_dict, ntrials=4, seed=42 + ) + _, cost2 = ctgr.optimize_random_greedy_track_flops( + inputs, output, size_dict, ntrials=4, seed=42 + ) + assert cost1 == cost2 + tree = ctg.ContractionTree.from_path( + inputs, output, size_dict, path=path + ) + assert tree.is_complete() + assert tree.contraction_cost(log=10) == pytest.approx(cost1) \ No newline at end of file