From 43bb6bf509a5c0cb4d608a37267d3e9681b3c58d Mon Sep 17 00:00:00 2001
From: Shahar Papini <spapini@starkware.co>
Date: Fri, 22 Mar 2024 11:36:23 +0200
Subject: [PATCH] Simple mixed merkle tree

---
 src/commitment_scheme/blake2_merkle.rs | 193 +++++++++++++++++++++++
 src/commitment_scheme/mod.rs           |   4 +
 src/commitment_scheme/ops.rs           |  40 +++++
 src/commitment_scheme/prover.rs        |  92 +++++++++++
 src/commitment_scheme/verifier.rs      | 210 +++++++++++++++++++++++++
 5 files changed, 539 insertions(+)
 create mode 100644 src/commitment_scheme/blake2_merkle.rs
 create mode 100644 src/commitment_scheme/ops.rs
 create mode 100644 src/commitment_scheme/prover.rs
 create mode 100644 src/commitment_scheme/verifier.rs

diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs
new file mode 100644
index 000000000..e3e48ff64
--- /dev/null
+++ b/src/commitment_scheme/blake2_merkle.rs
@@ -0,0 +1,193 @@
+use itertools::Itertools;
+use num_traits::Zero;
+
+use super::blake2s_ref::compress;
+use super::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::CPUBackend;
+use crate::core::fields::m31::BaseField;
+
+pub struct Blake2Hasher;
+impl MerkleHasher for Blake2Hasher {
+    type Hash = [u32; 8];
+
+    fn hash_node(
+        children_hashes: Option<(Self::Hash, Self::Hash)>,
+        column_values: &[BaseField],
+    ) -> Self::Hash {
+        let mut state = [0; 8];
+        if let Some((left, right)) = children_hashes {
+            state = compress(
+                state,
+                unsafe { std::mem::transmute([left, right]) },
+                0,
+                0,
+                0,
+                0,
+            );
+        }
+        let rem = 15 - ((column_values.len() + 15) % 16);
+        let padded_values = column_values
+            .iter()
+            .copied()
+            .chain(std::iter::repeat(BaseField::zero()).take(rem));
+        for chunk in padded_values.array_chunks::<16>() {
+            state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
+        }
+        state
+    }
+}
+
+impl MerkleOps<Blake2Hasher> for CPUBackend {
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Vec<[u32; 8]>>,
+        columns: &[&Vec<BaseField>],
+    ) -> Vec<[u32; 8]> {
+        (0..(1 << log_size))
+            .map(|i| {
+                Blake2Hasher::hash_node(
+                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
+                    &columns.iter().map(|column| column[i]).collect_vec(),
+                )
+            })
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use itertools::Itertools;
+    use num_traits::Zero;
+    use rand::rngs::StdRng;
+    use rand::{Rng, SeedableRng};
+
+    use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
+    use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
+    use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
+    use crate::core::backend::CPUBackend;
+    use crate::core::fields::m31::BaseField;
+
+    type TestData = (
+        Vec<usize>,
+        MerkleDecommitment<Blake2Hasher>,
+        Vec<(u32, Vec<BaseField>)>,
+        MerkleTreeVerifier<Blake2Hasher>,
+    );
+    fn prepare_merkle() -> TestData {
+        const N_COLS: usize = 400;
+        const N_QUERIES: usize = 7;
+
+        let rng = &mut StdRng::seed_from_u64(0);
+        let log_sizes = (0..N_COLS)
+            .map(|_| rng.gen_range(6..9))
+            .sorted()
+            .rev()
+            .collect_vec();
+        let max_log_size = *log_sizes.iter().max().unwrap();
+        let cols = log_sizes
+            .iter()
+            .map(|&log_size| {
+                (0..(1 << log_size))
+                    .map(|_| BaseField::from(rng.gen_range(0..(1 << 30))))
+                    .collect_vec()
+            })
+            .collect_vec();
+        let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());
+
+        let queries = (0..N_QUERIES)
+            .map(|_| rng.gen_range(0..(1 << max_log_size)))
+            .sorted()
+            .dedup()
+            .collect_vec();
+        let decommitment = merkle.decommit(queries.clone());
+        let values = cols
+            .iter()
+            .map(|col| {
+                let layer_queries = queries
+                    .iter()
+                    .map(|&q| q >> (max_log_size - col.len().ilog2()))
+                    .dedup();
+                layer_queries.map(|q| col[q]).collect_vec()
+            })
+            .collect_vec();
+        let values = log_sizes.into_iter().zip(values).collect_vec();
+
+        let verifier = MerkleTreeVerifier {
+            root: merkle.root(),
+        };
+        (queries, decommitment, values, verifier)
+    }
+
+    #[test]
+    fn test_merkle_success() {
+        let (queries, decommitment, values, verifier) = prepare_merkle();
+
+        verifier.verify(queries, values, decommitment).unwrap();
+    }
+
+    #[test]
+    fn test_merkle_invalid_witness() {
+        let (queries, mut decommitment, values, verifier) = prepare_merkle();
+        decommitment.witness[20] = [0; 8];
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::RootMismatch
+        );
+    }
+
+    #[test]
+    fn test_merkle_invalid_value() {
+        let (queries, decommitment, mut values, verifier) = prepare_merkle();
+        values[3].1[6] = BaseField::zero();
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::RootMismatch
+        );
+    }
+
+    #[test]
+    fn test_merkle_witness_too_short() {
+        let (queries, mut decommitment, values, verifier) = prepare_merkle();
+        decommitment.witness.pop();
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::WitnessTooShort
+        );
+    }
+
+    #[test]
+    fn test_merkle_column_values_too_long() {
+        let (queries, decommitment, mut values, verifier) = prepare_merkle();
+        values[3].1.push(BaseField::zero());
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::ColumnValuesTooLong
+        );
+    }
+
+    #[test]
+    fn test_merkle_column_values_too_short() {
+        let (queries, decommitment, mut values, verifier) = prepare_merkle();
+        values[3].1.pop();
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::ColumnValuesTooShort
+        );
+    }
+
+    #[test]
+    fn test_merkle_witness_too_long() {
+        let (queries, mut decommitment, values, verifier) = prepare_merkle();
+        decommitment.witness.push([0; 8]);
+
+        assert_eq!(
+            verifier.verify(queries, values, decommitment).unwrap_err(),
+            MerkleVerificationError::WitnessTooLong
+        );
+    }
+}
diff --git a/src/commitment_scheme/mod.rs b/src/commitment_scheme/mod.rs
index 23001ed95..c7108e08c 100644
--- a/src/commitment_scheme/mod.rs
+++ b/src/commitment_scheme/mod.rs
@@ -1,4 +1,5 @@
 pub mod blake2_hash;
+pub mod blake2_merkle;
 pub mod blake2s_ref;
 pub mod blake3_hash;
 pub mod hasher;
@@ -8,4 +9,7 @@ pub mod merkle_multilayer;
 pub mod merkle_tree;
 pub mod mixed_degree_decommitment;
 pub mod mixed_degree_merkle_tree;
+pub mod ops;
+pub mod prover;
 pub mod utils;
+pub mod verifier;
diff --git a/src/commitment_scheme/ops.rs b/src/commitment_scheme/ops.rs
new file mode 100644
index 000000000..65660f887
--- /dev/null
+++ b/src/commitment_scheme/ops.rs
@@ -0,0 +1,40 @@
+use crate::core::backend::{Col, ColumnOps};
+use crate::core::fields::m31::BaseField;
+
+/// A Merkle node hash is a hash of:
+///   [left_child_hash, right_child_hash], column0_value, column1_value, ...
+/// "[]" denotes optional values.
+/// The largest Merkle layer has no left and right child hashes. The rest of the layers have
+/// children hashes.
+/// At each layer, the tree may have multiple columns of the same length as the layer.
+/// Each node in that layer contains one value from each column.
+pub trait MerkleHasher {
+    type Hash: Clone + Eq + std::fmt::Debug;
+    /// Hashes a single Merkle node. See [MerkleHasher] for more details.
+    fn hash_node(
+        children_hashes: Option<(Self::Hash, Self::Hash)>,
+        column_values: &[BaseField],
+    ) -> Self::Hash;
+}
+
+/// Trait for performing Merkle operations on a commitment scheme.
+pub trait MerkleOps<H: MerkleHasher>: ColumnOps<BaseField> + ColumnOps<H::Hash> {
+    /// Commits on an entire layer of the Merkle tree.
+    /// See [MerkleHasher] for more details.
+    ///
+    /// The layer has 2^`log_size` nodes that need to be hashed. The topmost layer has 1 node,
+    /// which is a hash of 2 children and some columns.
+    ///
+    /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer.
+    /// That layer is assumed to have 2^(`log_size`+1) nodes.
+    ///
+    /// `columns` are the extra columns that need to be hashed in each node.
+    /// They are assumed to be of size 2^`log_size`.
+    ///
+    /// Returns the next Merkle layer hashes.
+    fn commit_on_layer(
+        log_size: u32,
+        prev_layer: Option<&Col<Self, H::Hash>>,
+        columns: &[&Col<Self, BaseField>],
+    ) -> Col<Self, H::Hash>;
+}
diff --git a/src/commitment_scheme/prover.rs b/src/commitment_scheme/prover.rs
new file mode 100644
index 000000000..8823213e2
--- /dev/null
+++ b/src/commitment_scheme/prover.rs
@@ -0,0 +1,92 @@
+use std::cmp::Reverse;
+
+use itertools::Itertools;
+
+use super::ops::{MerkleHasher, MerkleOps};
+use crate::core::backend::{Col, Column};
+use crate::core::fields::m31::BaseField;
+
+pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
+    /// Layers of the Merkle tree.
+    /// The first layer is the largest column.
+    /// The last layer is the root.
+    /// See [MerkleOps::commit_on_layer] for more details.
+    pub layers: Vec<Col<B, H::Hash>>,
+}
+/// The MerkleProver struct represents a prover for a Merkle commitment scheme.
+/// It is generic over the types `B` and `H`, which represent the Merkle operations and Merkle
+/// hasher respectively.
+impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
+    /// Commits to columns.
+    /// Columns must be of power of 2 sizes and sorted in descending order.
+    ///
+    /// # Arguments
+    ///
+    /// * `columns` - A vector of references to columns.
+    ///
+    /// # Panics
+    ///
+    /// This function will panic if the columns are not sorted in descending order or if the columns
+    /// vector is empty.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of `MerkleProver` with the committed layers.
+    pub fn commit(columns: Vec<&Col<B, BaseField>>) -> Self {
+        // Check that columns are of descending order.
+        assert!(!columns.is_empty());
+        assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));
+
+        let mut columns = &mut columns.into_iter().peekable();
+        let mut layers: Vec<Col<B, H::Hash>> = Vec::new();
+
+        let max_log_size = columns.peek().unwrap().len().ilog2();
+        for log_size in (0..=max_log_size).rev() {
+            // Take columns of the current log_size.
+            let layer_columns = (&mut columns)
+                .take_while(|column| column.len().ilog2() == log_size)
+                .collect_vec();
+
+            layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
+        }
+        Self { layers }
+    }
+
+    /// Decommits to columns on the given queries.
+    /// Queries are given as indices to the largest column.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - A vector of query indices to the largest column.
+    ///
+    /// # Returns
+    ///
+    /// A `Decommitment` struct containing the witness.
+    pub fn decommit(&self, mut queries: Vec<usize>) -> MerkleDecommitment<H> {
+        let mut witness = Vec::new();
+        for layer in &self.layers[..self.layers.len() - 1] {
+            let mut queries_iter = queries.into_iter().peekable();
+
+            // Propagate queries and hashes to the next layer.
+            let mut next_queries = Vec::new();
+            while let Some(query) = queries_iter.next() {
+                next_queries.push(query / 2);
+                if queries_iter.next_if_eq(&(query ^ 1)).is_some() {
+                    continue;
+                }
+                witness.push(layer.at(query ^ 1));
+            }
+            queries = next_queries;
+        }
+        MerkleDecommitment { witness }
+    }
+
+    pub fn root(&self) -> H::Hash {
+        self.layers.last().unwrap().at(0)
+    }
+}
+
+#[derive(Debug)]
+pub struct MerkleDecommitment<H: MerkleHasher> {
+    pub witness: Vec<H::Hash>,
+}
diff --git a/src/commitment_scheme/verifier.rs b/src/commitment_scheme/verifier.rs
new file mode 100644
index 000000000..dabff528f
--- /dev/null
+++ b/src/commitment_scheme/verifier.rs
@@ -0,0 +1,210 @@
+use std::cmp::Reverse;
+use std::iter::Peekable;
+
+use itertools::Itertools;
+use thiserror::Error;
+
+use super::ops::MerkleHasher;
+use super::prover::MerkleDecommitment;
+use crate::core::fields::m31::BaseField;
+
+// TODO(spapini): This struct is not necessary. Make it a function on decommitment?
+pub struct MerkleTreeVerifier<H: MerkleHasher> {
+    pub root: H::Hash,
+}
+impl<H: MerkleHasher> MerkleTreeVerifier<H> {
+    /// Verifies the decommitment of the columns.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - A vector of indices representing the queries to the largest column.
+    ///   Note: This is sufficient for bit reversed STARK commitments.
+    ///   It could be extended to support queries to any column.
+    /// * `values` - A vector of pairs containing the log_size of the column and the decommitted
+    ///   values of the column. Must be given in the same order as the columns were committed.
+    /// * `decommitment` - The decommitment object containing the witness and column values.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if any of the following conditions are met:
+    ///
+    /// * The witness is too long (not fully consumed).
+    /// * The witness is too short (missing values).
+    /// * The column values are too long (not fully consumed).
+    /// * The column values are too short (missing values).
+    /// * The computed root does not match the expected root.
+    ///
+    /// # Panics
+    ///
+    /// This function will panic if the `values` vector is not sorted in descending order based on
+    /// the `log_size` of the columns.
+    ///
+    /// # Returns
+    ///
+    /// Returns `Ok(())` if the decommitment is successfully verified.
+    pub fn verify(
+        &self,
+        queries: Vec<usize>,
+        values: Vec<(u32, Vec<BaseField>)>,
+        decommitment: MerkleDecommitment<H>,
+    ) -> Result<(), MerkleVerificationError> {
+        // Check that columns are of descending order.
+        assert!(values.is_sorted_by_key(|(log_size, _)| Reverse(log_size)));
+
+        // Compute root from decommitment.
+        let mut verifier = MerkleVerifier::<H> {
+            witness: decommitment.witness.into_iter(),
+            column_values: values.into_iter().peekable(),
+            layer_column_values: Vec::new(),
+        };
+        let computed_root = verifier.compute_root_from_decommitment(queries)?;
+
+        // Check that all witnesses and values have been consumed.
+        if !verifier.witness.is_empty() {
+            return Err(MerkleVerificationError::WitnessTooLong);
+        }
+        if !verifier.column_values.is_empty() {
+            return Err(MerkleVerificationError::ColumnValuesTooLong);
+        }
+
+        // Check that the computed root matches the expected root.
+        if computed_root != self.root {
+            return Err(MerkleVerificationError::RootMismatch);
+        }
+
+        Ok(())
+    }
+}
+
+/// A helper struct for verifying a [MerkleDecommitment].
+struct MerkleVerifier<H: MerkleHasher> {
+    /// A queue for consuming the next hash witness from the decommitment.
+    witness: std::vec::IntoIter<<H as MerkleHasher>::Hash>,
+    /// A queue for consuming the next claimed values for each column.
+    column_values: Peekable<std::vec::IntoIter<(u32, Vec<BaseField>)>>,
+    /// A queue for consuming the next claimed values for each column in the current layer.
+    layer_column_values: Vec<std::vec::IntoIter<BaseField>>,
+}
+impl<H: MerkleHasher> MerkleVerifier<H> {
+    /// Computes the root hash of a Merkle tree from the decommitment information.
+    ///
+    /// # Arguments
+    ///
+    /// * `queries` - A vector of query indices to the largest column.
+    ///
+    /// # Returns
+    ///
+    /// Returns the computed root hash of the Merkle tree.
+    ///
+    /// # Errors
+    ///
+    /// Returns a `MerkleVerificationError` if there is an error during the computation.
+    pub fn compute_root_from_decommitment(
+        &mut self,
+        queries: Vec<usize>,
+    ) -> Result<H::Hash, MerkleVerificationError> {
+        let max_log_size = self.column_values.peek().unwrap().0;
+        assert!(*queries.iter().max().unwrap() < 1 << max_log_size);
+
+        // A sequence of queries to the current layer.
+        // Each query is a pair of the query index and the known hashes of the children, if any.
+        // The known hashes are represented as ChildrenHashesAtQuery.
+        // None on the largest layer, or a pair of Option<Hash>, for the known hashes of the left
+        // and right children.
+        let mut queries = queries.into_iter().map(|query| (query, None)).collect_vec();
+
+        for layer_log_size in (0..=max_log_size).rev() {
+            // Take values for columns of the current log_size.
+            self.layer_column_values = (&mut self.column_values)
+                .take_while(|(log_size, _)| *log_size == layer_log_size)
+                .map(|(_, values)| values.into_iter())
+                .collect();
+
+            // Compute node hashes for the current layer.
+            let mut hashes_at_layer = queries
+                .into_iter()
+                .map(|(index, children_hashes)| (index, self.compute_node_hash(children_hashes)))
+                .peekable();
+
+            // Propagate queries and hashes to the next layer.
+            let mut next_queries = Vec::new();
+            while let Some((index, node_hash)) = hashes_at_layer.next() {
+                // If the sibling hash is known, propagate it to the next layer.
+                if let Some((_, sibling_hash)) =
+                    hashes_at_layer.next_if(|(next_index, _)| *next_index == index ^ 1)
+                {
+                    next_queries.push((index / 2, Some((Some(node_hash?), Some(sibling_hash?)))));
+                    continue;
+                }
+                // Otherwise, propagate the node hash to the next layer, in the correct direction.
+                if index & 1 == 0 {
+                    next_queries.push((index / 2, Some((Some(node_hash?), None))));
+                } else {
+                    next_queries.push((index / 2, Some((None, Some(node_hash?)))));
+                }
+            }
+            queries = next_queries;
+
+            // Check that all layer_column_values have been consumed.
+            if self
+                .layer_column_values
+                .iter_mut()
+                .any(|values| values.next().is_some())
+            {
+                return Err(MerkleVerificationError::ColumnValuesTooLong);
+            }
+        }
+
+        assert_eq!(queries.len(), 1);
+        Ok(queries.pop().unwrap().1.unwrap().0.unwrap())
+    }
+
+    fn compute_node_hash(
+        &mut self,
+        children_hashes: ChildrenHashesAtQuery<H>,
+    ) -> Result<H::Hash, MerkleVerificationError> {
+        // For each child with an unknown hash, fill it from the witness queue.
+        let hashes_part = children_hashes
+            .map(|(l, r)| {
+                let l = l
+                    .or_else(|| self.witness.next())
+                    .ok_or(MerkleVerificationError::WitnessTooShort)?;
+                let r = r
+                    .or_else(|| self.witness.next())
+                    .ok_or(MerkleVerificationError::WitnessTooShort)?;
+                Ok((l, r))
+            })
+            .transpose()?;
+        // Fill the column values from the layer_column_values queue.
+        let column_values = self
+            .layer_column_values
+            .iter_mut()
+            .map(|values| {
+                values
+                    .next()
+                    .ok_or(MerkleVerificationError::ColumnValuesTooShort)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        // Hash the node.
+        Ok(H::hash_node(hashes_part, &column_values))
+    }
+}
+
+type ChildrenHashesAtQuery<H> = Option<(
+    Option<<H as MerkleHasher>::Hash>,
+    Option<<H as MerkleHasher>::Hash>,
+)>;
+
+#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
+pub enum MerkleVerificationError {
+    #[error("Witness is too short.")]
+    WitnessTooShort,
+    #[error("Witness is too long.")]
+    WitnessTooLong,
+    #[error("Column values are too long.")]
+    ColumnValuesTooLong,
+    #[error("Column values are too short.")]
+    ColumnValuesTooShort,
+    #[error("Root mismatch.")]
+    RootMismatch,
+}