diff --git a/Cargo.lock b/Cargo.lock index 10f75503f..92a0508eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -385,6 +385,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "merging-iterator" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eda55172050e028dfd9d0b4335cbdbeaf81adf2b1f2dc463285e819e9a67eae" + [[package]] name = "num-traits" version = "0.2.17" @@ -458,6 +464,7 @@ dependencies = [ "criterion", "hex", "itertools 0.12.0", + "merging-iterator", "num-traits", "rand", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index bb669c5bc..60c1437e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ hex = "0.4.3" itertools = "0.12.0" num-traits = "0.2.17" thiserror = "1.0.56" +merging-iterator = "1.3.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } diff --git a/src/commitment_scheme/mixed_degree_merkle_tree.rs b/src/commitment_scheme/mixed_degree_merkle_tree.rs index 69caa9783..cac0ee6ef 100644 --- a/src/commitment_scheme/mixed_degree_merkle_tree.rs +++ b/src/commitment_scheme/mixed_degree_merkle_tree.rs @@ -1,4 +1,7 @@ +use std::iter::Peekable; + use itertools::Itertools; +use merging_iterator::MergeIter; use super::hasher::Hasher; use super::merkle_input::MerkleTreeInput; @@ -96,7 +99,8 @@ where } // Queries should be a query struct that supports queries at multiple layers. - pub fn decommit(&self, _queries: Vec>) -> MixedDecommitment { + // TODO(Ohad): introduce a proper query struct, then deprecate 'drain' usage and accepting vecs. + pub fn decommit(&self, mut _queries_per_column: Vec>) -> MixedDecommitment { todo!() } @@ -140,9 +144,68 @@ where self.multi_layers[layer_index].config.sub_tree_height } + // Generates the witness of a single layer and adds it to the decommitment. + // 'previous_layer_indices' - node indices that are part of the witness for a query below . + // 'queries_to_layer'- queries to columns at this layer. + #[allow(dead_code)] + fn decommit_single_layer( + &self, + layer_depth: usize, + queries_to_layer: impl Iterator> + Clone, + mut previous_layers_indices: Peekable + Clone>, + decommitment: &mut MixedDecommitment, + ) -> Vec { + let directly_queried_node_indices = + queried_nodes_in_layer(queries_to_layer.clone(), &self.input, layer_depth); + let mut index_value_iterator = directly_queried_node_indices + .iter() + .copied() + .zip(self.layer_felt_witnesses_and_queried_elements( + layer_depth, + queries_to_layer, + directly_queried_node_indices.iter().copied(), + )) + .peekable(); + let mut node_indices = MergeIter::new( + directly_queried_node_indices.iter().copied(), + previous_layers_indices.clone().map(|q| q / 2), + ) + .collect_vec(); + node_indices.dedup(); + + for &node_index in node_indices.iter() { + match previous_layers_indices.next_if(|&q| q / 2 == node_index) { + None if layer_depth < self.height() => { + // If the node is not a direct query, include both hashes. + let (l_hash, r_hash) = self.child_hashes(node_index, layer_depth); + decommitment.hashes.push(l_hash); + decommitment.hashes.push(r_hash); + } + Some(q) + if previous_layers_indices + .next_if(|&next_q| next_q ^ 1 == q) + .is_none() => + { + decommitment.hashes.push(self.sibling_hash(q, layer_depth)); + } + _ => {} + } + + if let Some((_, (witness, queried))) = + index_value_iterator.next_if(|(n, _)| *n == node_index) + { + decommitment.witness_elements.extend(witness); + decommitment.queried_values.extend(queried); + } else { + let injected_elements = self.input.get_injected_elements(layer_depth, node_index); + decommitment.witness_elements.extend(injected_elements); + } + } + node_indices + } + // Returns the felt witnesses and queried elements for the given node indices in the specified // layer. Assumes that the queries & node indices are sorted in ascending order. - #[allow(dead_code)] fn layer_felt_witnesses_and_queried_elements( &self, layer_depth: usize, @@ -178,13 +241,26 @@ where witnesses_and_queried_values_by_node } + + #[allow(dead_code)] + fn sibling_hash(&self, query: usize, layer_depth: usize) -> H::Hash { + self.get_hash_at(layer_depth, query ^ 1) + } + + #[allow(dead_code)] + fn child_hashes(&self, node_index: usize, layer_depth: usize) -> (H::Hash, H::Hash) { + ( + self.get_hash_at(layer_depth, node_index * 2), + self.get_hash_at(layer_depth, node_index * 2 + 1), + ) + } } // Translates queries of the form to the form // Input queries are per column, i.e queries[0] is a vector of queries for the first column that was // inserted to the tree's input in that layer. #[allow(dead_code)] -fn queried_node_indices_in_layer<'a>( +fn queried_nodes_in_layer<'a>( queries: impl Iterator>, input: &MerkleTreeInput<'_, impl Field>, layer_depth: usize, @@ -213,9 +289,7 @@ fn queried_node_indices_in_layer<'a>( mod tests { use std::vec; - use super::{ - queried_node_indices_in_layer, MixedDegreeMerkleTree, MixedDegreeMerkleTreeConfig, - }; + use super::{queried_nodes_in_layer, MixedDegreeMerkleTree, MixedDegreeMerkleTreeConfig}; use crate::commitment_scheme::blake2_hash::Blake2sHasher; use crate::commitment_scheme::blake3_hash::Blake3Hasher; use crate::commitment_scheme::hasher::Hasher; @@ -372,7 +446,7 @@ mod tests { let column_queries_at_depth = queries .drain(..n_columns_injected_at_depth) .collect::>(); - super::queried_node_indices_in_layer(column_queries_at_depth.iter(), input, i) + super::queried_nodes_in_layer(column_queries_at_depth.iter(), input, i) }) .collect::>>() } @@ -444,13 +518,13 @@ mod tests { third_column_queries, ]; - let node_indices = queried_node_indices_in_layer(queries.iter().take(2), &tree.input, 4); + let node_indices = queried_nodes_in_layer(queries.iter().take(2), &tree.input, 4); let w4 = tree.layer_felt_witnesses_and_queried_elements( 4, queries[..2].iter(), node_indices.iter().copied(), ); - let node_indices = queried_node_indices_in_layer(queries.iter().skip(2), &tree.input, 3); + let node_indices = queried_nodes_in_layer(queries.iter().skip(2), &tree.input, 3); let w3 = tree.layer_felt_witnesses_and_queried_elements( 4, queries[2..4].iter(),