From 1ebc2beeec84a8255fd49aee41507f6f65bfa4e1 Mon Sep 17 00:00:00 2001
From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com>
Date: Tue, 26 Mar 2024 09:16:43 +0200
Subject: [PATCH] Simple mixed merkle tree (#525)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This change is [Reviewable](https://reviewable.io/reviews/starkware-libs/stwo/525)
---
src/commitment_scheme/blake2_merkle.rs | 193 +++++++++++++++++++++++
src/commitment_scheme/mod.rs | 4 +
src/commitment_scheme/ops.rs | 40 +++++
src/commitment_scheme/prover.rs | 92 +++++++++++
src/commitment_scheme/verifier.rs | 210 +++++++++++++++++++++++++
5 files changed, 539 insertions(+)
create mode 100644 src/commitment_scheme/blake2_merkle.rs
create mode 100644 src/commitment_scheme/ops.rs
create mode 100644 src/commitment_scheme/prover.rs
create mode 100644 src/commitment_scheme/verifier.rs
diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
new file mode 100644
index 000000000..e3e48ff64
--- /dev/null
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -0,0 +1,193 @@
+use itertools::Itertools;
+use num_traits::Zero;
+
+use super::blake2s_ref::compress;
+use super::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::CPUBackend;
+use crate::core::fields::m31::BaseField;
+
+pub struct Blake2Hasher;
+impl MerkleHasher for Blake2Hasher {
+ type Hash = [u32; 8];
+
+ fn hash_node(
+ children_hashes: Option<(Self::Hash, Self::Hash)>,
+ column_values: &[BaseField],
+ ) -> Self::Hash {
+ let mut state = [0; 8];
+ if let Some((left, right)) = children_hashes {
+ state = compress(
+ state,
+ unsafe { std::mem::transmute([left, right]) },
+ 0,
+ 0,
+ 0,
+ 0,
+ );
+ }
+ let rem = 15 - ((column_values.len() + 15) % 16);
+ let padded_values = column_values
+ .iter()
+ .copied()
+ .chain(std::iter::repeat(BaseField::zero()).take(rem));
+ for chunk in padded_values.array_chunks::<16>() {
+ state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
+ }
+ state
+ }
+}
+
+impl MerkleOps<Blake2Hasher> for CPUBackend {
+ fn commit_on_layer(
+ log_size: u32,
+ prev_layer: Option<&Vec<[u32; 8]>>,
+ columns: &[&Vec<BaseField>],
+ ) -> Vec<[u32; 8]> {
+ (0..(1 << log_size))
+ .map(|i| {
+ Blake2Hasher::hash_node(
+ prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
+ &columns.iter().map(|column| column[i]).collect_vec(),
+ )
+ })
+ .collect()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use itertools::Itertools;
+ use num_traits::Zero;
+ use rand::rngs::StdRng;
+ use rand::{Rng, SeedableRng};
+
+ use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
+ use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
+ use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
+ use crate::core::backend::CPUBackend;
+ use crate::core::fields::m31::BaseField;
+
+ type TestData = (
+ Vec<usize>,
+ MerkleDecommitment<Blake2Hasher>,
+ Vec<(u32, Vec<BaseField>)>,
+ MerkleTreeVerifier<Blake2Hasher>,
+ );
+ fn prepare_merkle() -> TestData {
+ const N_COLS: usize = 400;
+ const N_QUERIES: usize = 7;
+
+ let rng = &mut StdRng::seed_from_u64(0);
+ let log_sizes = (0..N_COLS)
+ .map(|_| rng.gen_range(6..9))
+ .sorted()
+ .rev()
+ .collect_vec();
+ let max_log_size = *log_sizes.iter().max().unwrap();
+ let cols = log_sizes
+ .iter()
+ .map(|&log_size| {
+ (0..(1 << log_size))
+ .map(|_| BaseField::from(rng.gen_range(0..(1 << 30))))
+ .collect_vec()
+ })
+ .collect_vec();
+ let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());
+
+ let queries = (0..N_QUERIES)
+ .map(|_| rng.gen_range(0..(1 << max_log_size)))
+ .sorted()
+ .dedup()
+ .collect_vec();
+ let decommitment = merkle.decommit(queries.clone());
+ let values = cols
+ .iter()
+ .map(|col| {
+ let layer_queries = queries
+ .iter()
+ .map(|&q| q >> (max_log_size - col.len().ilog2()))
+ .dedup();
+ layer_queries.map(|q| col[q]).collect_vec()
+ })
+ .collect_vec();
+ let values = log_sizes.into_iter().zip(values).collect_vec();
+
+ let verifier = MerkleTreeVerifier {
+ root: merkle.root(),
+ };
+ (queries, decommitment, values, verifier)
+ }
+
+ #[test]
+ fn test_merkle_success() {
+ let (queries, decommitment, values, verifier) = prepare_merkle();
+
+ verifier.verify(queries, values, decommitment).unwrap();
+ }
+
+ #[test]
+ fn test_merkle_invalid_witness() {
+ let (queries, mut decommitment, values, verifier) = prepare_merkle();
+ decommitment.witness[20] = [0; 8];
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::RootMismatch
+ );
+ }
+
+ #[test]
+ fn test_merkle_invalid_value() {
+ let (queries, decommitment, mut values, verifier) = prepare_merkle();
+ values[3].1[6] = BaseField::zero();
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::RootMismatch
+ );
+ }
+
+ #[test]
+ fn test_merkle_witness_too_short() {
+ let (queries, mut decommitment, values, verifier) = prepare_merkle();
+ decommitment.witness.pop();
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::WitnessTooShort
+ );
+ }
+
+ #[test]
+ fn test_merkle_column_values_too_long() {
+ let (queries, decommitment, mut values, verifier) = prepare_merkle();
+ values[3].1.push(BaseField::zero());
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::ColumnValuesTooLong
+ );
+ }
+
+ #[test]
+ fn test_merkle_column_values_too_short() {
+ let (queries, decommitment, mut values, verifier) = prepare_merkle();
+ values[3].1.pop();
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::ColumnValuesTooShort
+ );
+ }
+
+ #[test]
+ fn test_merkle_witness_too_long() {
+ let (queries, mut decommitment, values, verifier) = prepare_merkle();
+ decommitment.witness.push([0; 8]);
+
+ assert_eq!(
+ verifier.verify(queries, values, decommitment).unwrap_err(),
+ MerkleVerificationError::WitnessTooLong
+ );
+ }
+}
diff --git a/src/commitment_scheme/mod.rs b/src/commitment_scheme/mod.rs
index 23001ed95..c7108e08c 100644
--- a/src/commitment_scheme/mod.rs
+++ b/src/commitment_scheme/mod.rs
@@ -1,4 +1,5 @@
pub mod blake2_hash;
+pub mod blake2_merkle;
pub mod blake2s_ref;
pub mod blake3_hash;
pub mod hasher;
@@ -8,4 +9,7 @@ pub mod merkle_multilayer;
pub mod merkle_tree;
pub mod mixed_degree_decommitment;
pub mod mixed_degree_merkle_tree;
+pub mod ops;
+pub mod prover;
pub mod utils;
+pub mod verifier;
diff --git a/src/commitment_scheme/ops.rs b/src/commitment_scheme/ops.rs
new file mode 100644
index 000000000..65660f887
--- /dev/null
+++ b/src/commitment_scheme/ops.rs
@@ -0,0 +1,40 @@
+use crate::core::backend::{Col, ColumnOps};
+use crate::core::fields::m31::BaseField;
+
+/// A Merkle node hash is a hash of:
+/// [left_child_hash, right_child_hash], column0_value, column1_value, ...
+/// "[]" denotes optional values.
+/// The largest Merkle layer has no left and right child hashes. The rest of the layers have
+/// children hashes.
+/// At each layer, the tree may have multiple columns of the same length as the layer.
+/// Each node in that layer contains one value from each column.
+pub trait MerkleHasher {
+ type Hash: Clone + Eq + std::fmt::Debug;
+ /// Hashes a single Merkle node. See [MerkleHasher] for more details.
+ fn hash_node(
+ children_hashes: Option<(Self::Hash, Self::Hash)>,
+ column_values: &[BaseField],
+ ) -> Self::Hash;
+}
+
+/// Trait for performing Merkle operations on a commitment scheme.
+pub trait MerkleOps<H: MerkleHasher>: ColumnOps<BaseField> + ColumnOps<H::Hash> {
+ /// Commits on an entire layer of the Merkle tree.
+ /// See [MerkleHasher] for more details.
+ ///
+ /// The layer has 2^`log_size` nodes that need to be hashed. The topmost layer has 1 node,
+ /// which is a hash of 2 children and some columns.
+ ///
+ /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer.
+ /// That layer is assumed to have 2^(`log_size`+1) nodes.
+ ///
+ /// `columns` are the extra columns that need to be hashed in each node.
+ /// They are assumed to be of size 2^`log_size`.
+ ///
+ /// Returns the next Merkle layer hashes.
+ fn commit_on_layer(
+ log_size: u32,
+ prev_layer: Option<&Col<Self, H::Hash>>,
+ columns: &[&Col<Self, BaseField>],
+ ) -> Col<Self, H::Hash>;
+}
diff --git a/src/commitment_scheme/prover.rs b/src/commitment_scheme/prover.rs
new file mode 100644
index 000000000..8823213e2
--- /dev/null
+++ b/src/commitment_scheme/prover.rs
@@ -0,0 +1,92 @@
+use std::cmp::Reverse;
+
+use itertools::Itertools;
+
+use super::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::{Col, Column};
+use crate::core::fields::m31::BaseField;
+
+pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
+ /// Layers of the Merkle tree.
+ /// The first layer is the largest column.
+ /// The last layer is the root.
+ /// See [MerkleOps::commit_on_layer] for more details.
+ pub layers: Vec<Col<B, H::Hash>>,
+}
+/// The MerkleProver struct represents a prover for a Merkle commitment scheme.
+/// It is generic over the types `B` and `H`, which represent the Merkle operations and Merkle
+/// hasher respectively.
+impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
+ /// Commits to columns.
+ /// Columns must be of power of 2 sizes and sorted in descending order.
+ ///
+ /// # Arguments
+ ///
+ /// * `columns` - A vector of references to columns.
+ ///
+ /// # Panics
+ ///
+ /// This function will panic if the columns are not sorted in descending order or if the columns
+ /// vector is empty.
+ ///
+ /// # Returns
+ ///
+ /// A new instance of `MerkleProver` with the committed layers.
+ pub fn commit(columns: Vec<&Col<B, BaseField>>) -> Self {
+ // Check that columns are of descending order.
+ assert!(!columns.is_empty());
+ assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));
+
+ let mut columns = &mut columns.into_iter().peekable();
+ let mut layers: Vec<Col<B, H::Hash>> = Vec::new();
+
+ let max_log_size = columns.peek().unwrap().len().ilog2();
+ for log_size in (0..=max_log_size).rev() {
+ // Take columns of the current log_size.
+ let layer_columns = (&mut columns)
+ .take_while(|column| column.len().ilog2() == log_size)
+ .collect_vec();
+
+ layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
+ }
+ Self { layers }
+ }
+
+ /// Decommits to columns on the given queries.
+ /// Queries are given as indices to the largest column.
+ ///
+ /// # Arguments
+ ///
+ /// * `queries` - A vector of query indices to the largest column.
+ ///
+ /// # Returns
+ ///
+ /// A `Decommitment` struct containing the witness.
+ pub fn decommit(&self, mut queries: Vec<usize>) -> MerkleDecommitment<H> {
+ let mut witness = Vec::new();
+ for layer in &self.layers[..self.layers.len() - 1] {
+ let mut queries_iter = queries.into_iter().peekable();
+
+ // Propagate queries and hashes to the next layer.
+ let mut next_queries = Vec::new();
+ while let Some(query) = queries_iter.next() {
+ next_queries.push(query / 2);
+ if queries_iter.next_if_eq(&(query ^ 1)).is_some() {
+ continue;
+ }
+ witness.push(layer.at(query ^ 1));
+ }
+ queries = next_queries;
+ }
+ MerkleDecommitment { witness }
+ }
+
+ pub fn root(&self) -> H::Hash {
+ self.layers.last().unwrap().at(0)
+ }
+}
+
+#[derive(Debug)]
+pub struct MerkleDecommitment<H: MerkleHasher> {
+ pub witness: Vec<H::Hash>,
+}
diff --git a/src/commitment_scheme/verifier.rs b/src/commitment_scheme/verifier.rs
new file mode 100644
index 000000000..dabff528f
--- /dev/null
+++ b/src/commitment_scheme/verifier.rs
@@ -0,0 +1,210 @@
+use std::cmp::Reverse;
+use std::iter::Peekable;
+
+use itertools::Itertools;
+use thiserror::Error;
+
+use super::ops::MerkleHasher;
+use super::prover::MerkleDecommitment;
+use crate::core::fields::m31::BaseField;
+
+// TODO(spapini): This struct is not necessary. Make it a function on decommitment?
+pub struct MerkleTreeVerifier<H: MerkleHasher> {
+ pub root: H::Hash,
+}
+impl<H: MerkleHasher> MerkleTreeVerifier<H> {
+ /// Verifies the decommitment of the columns.
+ ///
+ /// # Arguments
+ ///
+ /// * `queries` - A vector of indices representing the queries to the largest column.
+ /// Note: This is sufficient for bit reversed STARK commitments.
+ /// It could be extended to support queries to any column.
+ /// * `values` - A vector of pairs containing the log_size of the column and the decommitted
+ /// values of the column. Must be given in the same order as the columns were committed.
+ /// * `decommitment` - The decommitment object containing the witness and column values.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if any of the following conditions are met:
+ ///
+ /// * The witness is too long (not fully consumed).
+ /// * The witness is too short (missing values).
+ /// * The column values are too long (not fully consumed).
+ /// * The column values are too short (missing values).
+ /// * The computed root does not match the expected root.
+ ///
+ /// # Panics
+ ///
+ /// This function will panic if the `values` vector is not sorted in descending order based on
+ /// the `log_size` of the columns.
+ ///
+ /// # Returns
+ ///
+ /// Returns `Ok(())` if the decommitment is successfully verified.
+ pub fn verify(
+ &self,
+ queries: Vec<usize>,
+ values: Vec<(u32, Vec<BaseField>)>,
+ decommitment: MerkleDecommitment<H>,
+ ) -> Result<(), MerkleVerificationError> {
+ // Check that columns are of descending order.
+ assert!(values.is_sorted_by_key(|(log_size, _)| Reverse(log_size)));
+
+ // Compute root from decommitment.
+ let mut verifier = MerkleVerifier::<H> {
+ witness: decommitment.witness.into_iter(),
+ column_values: values.into_iter().peekable(),
+ layer_column_values: Vec::new(),
+ };
+ let computed_root = verifier.compute_root_from_decommitment(queries)?;
+
+ // Check that all witnesses and values have been consumed.
+ if !verifier.witness.is_empty() {
+ return Err(MerkleVerificationError::WitnessTooLong);
+ }
+ if !verifier.column_values.is_empty() {
+ return Err(MerkleVerificationError::ColumnValuesTooLong);
+ }
+
+ // Check that the computed root matches the expected root.
+ if computed_root != self.root {
+ return Err(MerkleVerificationError::RootMismatch);
+ }
+
+ Ok(())
+ }
+}
+
+/// A helper struct for verifying a [MerkleDecommitment].
+struct MerkleVerifier<H: MerkleHasher> {
+ /// A queue for consuming the next hash witness from the decommitment.
+ witness: std::vec::IntoIter<<H as MerkleHasher>::Hash>,
+ /// A queue for consuming the next claimed values for each column.
+ column_values: Peekable<std::vec::IntoIter<(u32, Vec<BaseField>)>>,
+ /// A queue for consuming the next claimed values for each column in the current layer.
+ layer_column_values: Vec<std::vec::IntoIter<BaseField>>,
+}
+impl<H: MerkleHasher> MerkleVerifier<H> {
+ /// Computes the root hash of a Merkle tree from the decommitment information.
+ ///
+ /// # Arguments
+ ///
+ /// * `queries` - A vector of query indices to the largest column.
+ ///
+ /// # Returns
+ ///
+ /// Returns the computed root hash of the Merkle tree.
+ ///
+ /// # Errors
+ ///
+ /// Returns a `MerkleVerificationError` if there is an error during the computation.
+ pub fn compute_root_from_decommitment(
+ &mut self,
+ queries: Vec<usize>,
+ ) -> Result<H::Hash, MerkleVerificationError> {
+ let max_log_size = self.column_values.peek().unwrap().0;
+ assert!(*queries.iter().max().unwrap() < 1 << max_log_size);
+
+ // A sequence of queries to the current layer.
+ // Each query is a pair of the query index and the known hashes of the children, if any.
+ // The known hashes are represented as ChildrenHashesAtQuery.
+ // None on the largest layer, or a pair of Option, for the known hashes of the left
+ // and right children.
+ let mut queries = queries.into_iter().map(|query| (query, None)).collect_vec();
+
+ for layer_log_size in (0..=max_log_size).rev() {
+ // Take values for columns of the current log_size.
+ self.layer_column_values = (&mut self.column_values)
+ .take_while(|(log_size, _)| *log_size == layer_log_size)
+ .map(|(_, values)| values.into_iter())
+ .collect();
+
+ // Compute node hashes for the current layer.
+ let mut hashes_at_layer = queries
+ .into_iter()
+ .map(|(index, children_hashes)| (index, self.compute_node_hash(children_hashes)))
+ .peekable();
+
+ // Propagate queries and hashes to the next layer.
+ let mut next_queries = Vec::new();
+ while let Some((index, node_hash)) = hashes_at_layer.next() {
+ // If the sibling hash is known, propagate it to the next layer.
+ if let Some((_, sibling_hash)) =
+ hashes_at_layer.next_if(|(next_index, _)| *next_index == index ^ 1)
+ {
+ next_queries.push((index / 2, Some((Some(node_hash?), Some(sibling_hash?)))));
+ continue;
+ }
+ // Otherwise, propagate the node hash to the next layer, in the correct direction.
+ if index & 1 == 0 {
+ next_queries.push((index / 2, Some((Some(node_hash?), None))));
+ } else {
+ next_queries.push((index / 2, Some((None, Some(node_hash?)))));
+ }
+ }
+ queries = next_queries;
+
+ // Check that all layer_column_values have been consumed.
+ if self
+ .layer_column_values
+ .iter_mut()
+ .any(|values| values.next().is_some())
+ {
+ return Err(MerkleVerificationError::ColumnValuesTooLong);
+ }
+ }
+
+ assert_eq!(queries.len(), 1);
+ Ok(queries.pop().unwrap().1.unwrap().0.unwrap())
+ }
+
+ fn compute_node_hash(
+ &mut self,
+ children_hashes: ChildrenHashesAtQuery<H>,
+ ) -> Result<H::Hash, MerkleVerificationError> {
+ // For each child with an unknown hash, fill it from the witness queue.
+ let hashes_part = children_hashes
+ .map(|(l, r)| {
+ let l = l
+ .or_else(|| self.witness.next())
+ .ok_or(MerkleVerificationError::WitnessTooShort)?;
+ let r = r
+ .or_else(|| self.witness.next())
+ .ok_or(MerkleVerificationError::WitnessTooShort)?;
+ Ok((l, r))
+ })
+ .transpose()?;
+ // Fill the column values from the layer_column_values queue.
+ let column_values = self
+ .layer_column_values
+ .iter_mut()
+ .map(|values| {
+ values
+ .next()
+ .ok_or(MerkleVerificationError::ColumnValuesTooShort)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ // Hash the node.
+ Ok(H::hash_node(hashes_part, &column_values))
+ }
+}
+
+type ChildrenHashesAtQuery<H> = Option<(
+ Option<<H as MerkleHasher>::Hash>,
+ Option<<H as MerkleHasher>::Hash>,
+)>;
+
+#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
+pub enum MerkleVerificationError {
+ #[error("Witness is too short.")]
+ WitnessTooShort,
+ #[error("Witness is too long.")]
+ WitnessTooLong,
+ #[error("Column values are too long.")]
+ ColumnValuesTooLong,
+ #[error("Column values are too short.")]
+ ColumnValuesTooShort,
+ #[error("Root mismatch.")]
+ RootMismatch,
+}