From 74d0ad52fa20eb9153a81551cd1c0adeeab3412b Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Fri, 22 Mar 2024 11:36:23 +0200 Subject: [PATCH] Simple mixed merkle tree --- src/commitment_scheme/blake2_merkle.rs | 193 +++++++++++++++++++++++++ src/commitment_scheme/mod.rs | 4 + src/commitment_scheme/ops.rs | 30 ++++ src/commitment_scheme/prover.rs | 65 +++++++++ src/commitment_scheme/verifier.rs | 164 +++++++++++++++++++++ 5 files changed, 456 insertions(+) create mode 100644 src/commitment_scheme/blake2_merkle.rs create mode 100644 src/commitment_scheme/ops.rs create mode 100644 src/commitment_scheme/prover.rs create mode 100644 src/commitment_scheme/verifier.rs diff --git a/src/commitment_scheme/blake2_merkle.rs b/src/commitment_scheme/blake2_merkle.rs new file mode 100644 index 000000000..7e65d556d --- /dev/null +++ b/src/commitment_scheme/blake2_merkle.rs @@ -0,0 +1,193 @@ +use itertools::Itertools; +use num_traits::Zero; + +use super::blake2s_ref::compress; +use super::ops::{MerkleHasher, MerkleOps}; +use crate::core::backend::CPUBackend; +use crate::core::fields::m31::BaseField; + +pub struct Blake2Hasher; +impl MerkleHasher for Blake2Hasher { + type Hash = [u32; 8]; + + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash { + let mut state = [0; 8]; + if let Some((left, right)) = children_hashes { + state = compress( + state, + unsafe { std::mem::transmute([left, right]) }, + 0, + 0, + 0, + 0, + ); + } + let rem = 15 - ((column_values.len() + 15) % 16); + let padded_values = column_values + .iter() + .copied() + .chain(std::iter::repeat(BaseField::zero()).take(rem)); + for chunk in padded_values.array_chunks::<16>() { + state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + } + state + } +} + +impl MerkleOps for CPUBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec<[u32; 8]>>, + columns: &[&Vec], + ) -> Vec<[u32; 8]> { + (0..(1 << log_size)) + .map(|i| { + Blake2Hasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column[i]).collect_vec(), + ) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::Zero; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + use crate::commitment_scheme::blake2_merkle::Blake2Hasher; + use crate::commitment_scheme::prover::{Decommitment, MerkleProver}; + use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError}; + use crate::core::backend::CPUBackend; + use crate::core::fields::m31::BaseField; + + type TestData = ( + Vec, + Decommitment, + Vec<(u32, Vec)>, + MerkleTreeVerifier, + ); + fn prepare_merkle() -> TestData { + const N_COLS: usize = 400; + const N_QUERIES: usize = 7; + + let rng = &mut StdRng::seed_from_u64(0); + let log_sizes = (0..N_COLS) + .map(|_| rng.gen_range(6..9)) + .sorted() + .rev() + .collect_vec(); + let max_log_size = *log_sizes.iter().max().unwrap(); + let cols = log_sizes + .iter() + .map(|&log_size| { + (0..(1 << log_size)) + .map(|_| BaseField::from(rng.gen_range(0..(1 << 30)))) + .collect_vec() + }) + .collect_vec(); + let merkle = MerkleProver::::commit(cols.iter().collect_vec()); + + let queries = (0..N_QUERIES) + .map(|_| rng.gen_range(0..(1 << max_log_size))) + .sorted() + .dedup() + .collect_vec(); + let decommitment = merkle.decommit(queries.clone()); + let values = cols + .iter() + .map(|col| { + let layer_queries = queries + .iter() + .map(|&q| q >> (max_log_size - col.len().ilog2())) + .dedup(); + layer_queries.map(|q| col[q]).collect_vec() + }) + .collect_vec(); + let values = log_sizes.into_iter().zip(values).collect_vec(); + + let verifier = MerkleTreeVerifier { + root: merkle.root(), + }; + (queries, decommitment, values, verifier) + } + + #[test] + fn test_merkle_success() { + let (queries, decommitment, values, verifier) = prepare_merkle(); + + verifier.verify(queries, values, decommitment).unwrap(); + } + + #[test] + fn test_merkle_invalid_witness() { + let (queries, mut decommitment, values, verifier) = prepare_merkle(); + decommitment.witness[20] = [0; 8]; + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_invalid_value() { + let (queries, decommitment, mut values, verifier) = prepare_merkle(); + values[3].1[6] = BaseField::zero(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_witness_too_short() { + let (queries, mut decommitment, values, verifier) = prepare_merkle(); + decommitment.witness.pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooShort + ); + } + + #[test] + fn test_merkle_column_values_too_long() { + let (queries, decommitment, mut values, verifier) = prepare_merkle(); + values[3].1.push(BaseField::zero()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_short() { + let (queries, decommitment, mut values, verifier) = prepare_merkle(); + values[3].1.pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooShort + ); + } + + #[test] + fn test_merkle_witness_too_long() { + let (queries, mut decommitment, values, verifier) = prepare_merkle(); + decommitment.witness.push([0; 8]); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooLong + ); + } +} diff --git a/src/commitment_scheme/mod.rs b/src/commitment_scheme/mod.rs index 688f2c8e1..689295051 100644 --- a/src/commitment_scheme/mod.rs +++ b/src/commitment_scheme/mod.rs @@ -1,4 +1,5 @@ pub mod blake2_hash; +pub mod blake2_merkle; #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] pub mod blake2s_avx; pub mod blake2s_ref; @@ -10,4 +11,7 @@ pub mod merkle_multilayer; pub mod merkle_tree; pub mod mixed_degree_decommitment; pub mod mixed_degree_merkle_tree; +pub mod ops; +pub mod prover; pub mod utils; +pub mod verifier; diff --git a/src/commitment_scheme/ops.rs b/src/commitment_scheme/ops.rs new file mode 100644 index 000000000..94e3c74c7 --- /dev/null +++ b/src/commitment_scheme/ops.rs @@ -0,0 +1,30 @@ +use crate::core::backend::{Col, ColumnOps}; +use crate::core::fields::m31::BaseField; + +pub trait MerkleHasher { + type Hash: Clone + Eq + std::fmt::Debug; + /// Hashes a single Merkle node. + /// The node may or may not need to hash 2 hashes from the previous layer - depending if it is a + /// leaf or not. + /// In addition, the node may have extra column values that need to be hashed. + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash; +} + +pub trait MerkleOps: ColumnOps + ColumnOps { + /// Commits on an entire layer of the Merkle tree. + /// The layer has 2^`log_size` nodes that need be hashed. The top most layer has 1 node, + /// which is a hash of 2 children and some columns. + /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer.. + /// That layer is assumed to have 2^(`log_size`+1) nodes. + /// `columns` are the extra columns that need to be hashed in each node. + /// They are assumed to be on size 2^`log_size`. + /// Return the next Merkle layer hashes. + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Col>, + columns: &[&Col], + ) -> Col; +} diff --git a/src/commitment_scheme/prover.rs b/src/commitment_scheme/prover.rs new file mode 100644 index 000000000..b1d364130 --- /dev/null +++ b/src/commitment_scheme/prover.rs @@ -0,0 +1,65 @@ +use std::cmp::Reverse; + +use itertools::Itertools; + +use super::ops::{MerkleHasher, MerkleOps}; +use crate::core::backend::{Col, Column}; +use crate::core::fields::m31::BaseField; + +pub struct MerkleProver, H: MerkleHasher> { + pub layers: Vec>, +} +impl, H: MerkleHasher> MerkleProver { + /// Commits to columns. + /// Columns must be of power of 2 sizes and sorted in descending order. + pub fn commit(columns: Vec<&Col>) -> Self { + // Check that columns are of descending order. + assert!(!columns.is_empty()); + assert!(columns.is_sorted_by_key(|c| Reverse(c.len()))); + + let mut columns = &mut columns.into_iter().peekable(); + let mut layers: Vec> = Vec::new(); + + let max_log_size = columns.peek().unwrap().len().ilog2(); + for log_size in (0..=max_log_size).rev() { + // Take columns of the current log_size. + let layer_columns = (&mut columns) + .take_while(|column| column.len().ilog2() == log_size) + .collect_vec(); + + layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns)); + } + Self { layers } + } + + /// Decommits to columns on the given queries. + /// Queries are given as indices to the largest column. + pub fn decommit(&self, mut queries: Vec) -> Decommitment { + let mut witness = Vec::new(); + for layer in &self.layers { + let mut queries_iter = queries.into_iter().peekable(); + + // Propagate queries and hashes to the next layer. + let mut next_queries = Vec::new(); + while let Some(query) = queries_iter.next() { + next_queries.push(query / 2); + if queries_iter.next_if_eq(&(query ^ 1)).is_some() { + continue; + } + if layer.len() > 1 { + witness.push(layer.at(query ^ 1)); + } + } + queries = next_queries; + } + Decommitment { witness } + } + + pub fn root(&self) -> H::Hash { + self.layers.last().unwrap().at(0) + } +} + +pub struct Decommitment { + pub witness: Vec, +} diff --git a/src/commitment_scheme/verifier.rs b/src/commitment_scheme/verifier.rs new file mode 100644 index 000000000..b2c890ef0 --- /dev/null +++ b/src/commitment_scheme/verifier.rs @@ -0,0 +1,164 @@ +use std::cmp::Reverse; +use std::iter::Peekable; + +use itertools::Itertools; +use thiserror::Error; + +use super::ops::MerkleHasher; +use super::prover::Decommitment; +use crate::core::fields::m31::BaseField; + +pub struct MerkleTreeVerifier { + pub root: H::Hash, +} +impl MerkleTreeVerifier { + /// Verifies the decommitment of the columns. + /// Queries are given as indices to the largest column. + /// Values are given as pair of log_size of the column, and the decommited values of the + /// column. + /// Must be given in the same order as the columns were committed. + pub fn verify( + &self, + queries: Vec, + values: Vec<(u32, Vec)>, + decommitment: Decommitment, + ) -> Result<(), MerkleVerificationError> { + // Check that columns are of descending order. + assert!(values.is_sorted_by_key(|(log_size, _)| Reverse(log_size))); + + // Compute root from decommitment. + let mut verifier = MerkleVerifier:: { + witness: decommitment.witness.into_iter(), + column_values: values.into_iter().peekable(), + layer_column_values: Vec::new(), + }; + let computed_root = verifier.compute_root_from_decommitment(queries)?; + + // Check that all witnesses and values have been consumed. + if !verifier.witness.is_empty() { + return Err(MerkleVerificationError::WitnessTooLong); + } + if !verifier.column_values.is_empty() { + return Err(MerkleVerificationError::ColumnValuesTooLong); + } + + // Check that the computed root matches the expected root. + if computed_root != self.root { + return Err(MerkleVerificationError::RootMismatch); + } + + Ok(()) + } +} + +struct MerkleVerifier { + witness: std::vec::IntoIter<::Hash>, + column_values: Peekable)>>, + layer_column_values: Vec>, +} +impl MerkleVerifier { + pub fn compute_root_from_decommitment( + &mut self, + queries: Vec, + ) -> Result { + let max_log_size = self.column_values.peek().unwrap().0; + + // A sequence of queries to the current layer. + // Each query is a pair of the query index and the known hashes of the children, if any. + // The known hashes are represented as ChildrenHashesAtQuery. + // None on the largest layer, or a pair of Option, for the known hashes of the left + // and right children. + let mut queries = queries.into_iter().map(|query| (query, None)).collect_vec(); + + for layer_log_size in (0..=max_log_size).rev() { + // Take values for columns of the current log_size. + self.layer_column_values = (&mut self.column_values) + .take_while(|(log_size, _)| *log_size == layer_log_size) + .map(|(_, values)| values.into_iter()) + .collect(); + + // Compute node hashes for the current layer. + let mut hashes_at_layer = queries + .into_iter() + .map(|(index, children_hashes)| (index, self.compute_node_hash(children_hashes))) + .peekable(); + + // Propagate queries and hashes to the next layer. + let mut next_queries = Vec::new(); + while let Some((index, node_hash)) = hashes_at_layer.next() { + // If the sibling hash is known, propagate it to the next layer. + if let Some((_, sibling_hash)) = + hashes_at_layer.next_if(|(next_index, _)| *next_index == index ^ 1) + { + next_queries.push((index / 2, Some((Some(node_hash?), Some(sibling_hash?))))); + continue; + } + // Otherwise, propagate the node hash to the next layer, in the correct direction. + if index & 1 == 0 { + next_queries.push((index / 2, Some((Some(node_hash?), None)))); + } else { + next_queries.push((index / 2, Some((None, Some(node_hash?))))); + } + } + queries = next_queries; + + // Check that all layer_column_values have been consumed. + if self + .layer_column_values + .iter_mut() + .any(|values| values.next().is_some()) + { + return Err(MerkleVerificationError::ColumnValuesTooLong); + } + } + + assert_eq!(queries.len(), 1); + Ok(queries.pop().unwrap().1.unwrap().0.unwrap()) + } + + fn compute_node_hash( + &mut self, + children_hashes: ChildrenHashesAtQuery, + ) -> Result { + let hashes_part = children_hashes + .map(|(l, r)| { + let l = l + .or_else(|| self.witness.next()) + .ok_or(MerkleVerificationError::WitnessTooShort)?; + let r = r + .or_else(|| self.witness.next()) + .ok_or(MerkleVerificationError::WitnessTooShort)?; + Ok((l, r)) + }) + .transpose()?; + let column_values = self + .layer_column_values + .iter_mut() + .map(|values| { + values + .next() + .ok_or(MerkleVerificationError::ColumnValuesTooShort) + }) + .collect::, _>>()?; + Ok(H::hash_node(hashes_part, &column_values)) + } +} + +type ChildrenHashesAtQuery = Option<( + Option<::Hash>, + Option<::Hash>, +)>; + +#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] +pub enum MerkleVerificationError { + #[error("Witness is too short.")] + WitnessTooShort, + #[error("Witness is too long.")] + WitnessTooLong, + #[error("Column values are too long.")] + ColumnValuesTooLong, + #[error("Column values are too short.")] + ColumnValuesTooShort, + #[error("Root mismatch.")] + RootMismatch, +}