merkle intermediate layers
spapinistarkware committed Mar 28, 2024
1 parent 136858a commit 9030ba8
Showing 4 changed files with 334 additions and 195 deletions.
81 changes: 44 additions & 37 deletions src/commitment_scheme/blake2_merkle.rs
@@ -6,9 +6,11 @@ use super::ops::{MerkleHasher, MerkleOps};
 use crate::core::backend::CPUBackend;
 use crate::core::fields::m31::BaseField;
 
+#[derive(Copy, Clone, PartialEq, Eq, Default)]
+pub struct Blake2sHash(pub [u32; 8]);
 pub struct Blake2Hasher;
 impl MerkleHasher for Blake2Hasher {
-    type Hash = [u32; 8];
+    type Hash = Blake2sHash;
 
     fn hash_node(
         children_hashes: Option<(Self::Hash, Self::Hash)>,
@@ -33,16 +35,26 @@ impl MerkleHasher for Blake2Hasher {
         for chunk in padded_values.array_chunks::<16>() {
             state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
         }
-        state
+        Blake2sHash(state)
     }
 }
 
+impl std::fmt::Debug for Blake2sHash {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Write as hex.
+        for &byte in self.0.iter() {
+            write!(f, "{:02x}", byte)?;
+        }
+        Ok(())
+    }
+}
+
 impl MerkleOps<Blake2Hasher> for CPUBackend {
     fn commit_on_layer(
         log_size: u32,
-        prev_layer: Option<&Vec<[u32; 8]>>,
+        prev_layer: Option<&Vec<Blake2sHash>>,
         columns: &[&Vec<BaseField>],
-    ) -> Vec<[u32; 8]> {
+    ) -> Vec<Blake2sHash> {
         (0..(1 << log_size))
             .map(|i| {
                 Blake2Hasher::hash_node(
@@ -56,34 +68,34 @@ impl MerkleOps<Blake2Hasher> for CPUBackend {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::BTreeMap;
+
     use itertools::Itertools;
     use num_traits::Zero;
     use rand::rngs::StdRng;
     use rand::{Rng, SeedableRng};
 
-    use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
+    use crate::commitment_scheme::blake2_merkle::{Blake2Hasher, Blake2sHash};
     use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
-    use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
+    use crate::commitment_scheme::verifier::{MerkleVerificationError, MerkleVerifier};
     use crate::core::backend::CPUBackend;
     use crate::core::fields::m31::BaseField;
 
     type TestData = (
-        Vec<usize>,
+        BTreeMap<u32, Vec<usize>>,
         MerkleDecommitment<Blake2Hasher>,
-        Vec<(u32, Vec<BaseField>)>,
-        MerkleTreeVerifier<Blake2Hasher>,
+        Vec<Vec<BaseField>>,
+        MerkleVerifier<Blake2Hasher>,
     );
     fn prepare_merkle() -> TestData {
         const N_COLS: usize = 400;
         const N_QUERIES: usize = 7;
+        let log_size_range = 6..9;
 
         let rng = &mut StdRng::seed_from_u64(0);
         let log_sizes = (0..N_COLS)
-            .map(|_| rng.gen_range(6..9))
-            .sorted()
-            .rev()
+            .map(|_| rng.gen_range(log_size_range.clone()))
             .collect_vec();
-        let max_log_size = *log_sizes.iter().max().unwrap();
         let cols = log_sizes
             .iter()
             .map(|&log_size| {
@@ -94,26 +106,21 @@ mod tests {
             .collect_vec();
         let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());
 
-        let queries = (0..N_QUERIES)
-            .map(|_| rng.gen_range(0..(1 << max_log_size)))
-            .sorted()
-            .dedup()
-            .collect_vec();
-        let decommitment = merkle.decommit(queries.clone());
-        let values = cols
-            .iter()
-            .map(|col| {
-                let layer_queries = queries
-                    .iter()
-                    .map(|&q| q >> (max_log_size - col.len().ilog2()))
-                    .dedup();
-                layer_queries.map(|q| col[q]).collect_vec()
-            })
-            .collect_vec();
-        let values = log_sizes.into_iter().zip(values).collect_vec();
+        let mut queries = BTreeMap::<u32, Vec<usize>>::new();
+        for log_size in log_size_range.rev() {
+            let layer_queries = (0..N_QUERIES)
+                .map(|_| rng.gen_range(0..(1 << log_size)))
+                .sorted()
+                .dedup()
+                .collect_vec();
+            queries.insert(log_size, layer_queries);
+        }
+
+        let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec());
 
-        let verifier = MerkleTreeVerifier {
+        let verifier = MerkleVerifier {
             root: merkle.root(),
             column_log_sizes: log_sizes,
         };
         (queries, decommitment, values, verifier)
     }
@@ -128,7 +135,7 @@
     #[test]
     fn test_merkle_invalid_witness() {
         let (queries, mut decommitment, values, verifier) = prepare_merkle();
-        decommitment.witness[20] = [0; 8];
+        decommitment.hash_witness[20] = Blake2sHash([0; 8]);
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -139,7 +146,7 @@
     #[test]
     fn test_merkle_invalid_value() {
         let (queries, decommitment, mut values, verifier) = prepare_merkle();
-        values[3].1[6] = BaseField::zero();
+        values[3][6] = BaseField::zero();
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -150,7 +157,7 @@
     #[test]
     fn test_merkle_witness_too_short() {
         let (queries, mut decommitment, values, verifier) = prepare_merkle();
-        decommitment.witness.pop();
+        decommitment.hash_witness.pop();
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -161,7 +168,7 @@
     #[test]
     fn test_merkle_column_values_too_long() {
         let (queries, decommitment, mut values, verifier) = prepare_merkle();
-        values[3].1.push(BaseField::zero());
+        values[3].push(BaseField::zero());
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -172,7 +179,7 @@
     #[test]
     fn test_merkle_column_values_too_short() {
         let (queries, decommitment, mut values, verifier) = prepare_merkle();
-        values[3].1.pop();
+        values[3].pop();
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -183,7 +190,7 @@
     #[test]
     fn test_merkle_witness_too_long() {
         let (queries, mut decommitment, values, verifier) = prepare_merkle();
-        decommitment.witness.push([0; 8]);
+        decommitment.hash_witness.push(Blake2sHash([0; 8]));
 
         assert_eq!(
             verifier.verify(queries, values, decommitment).unwrap_err(),
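Note on the new test shape: queries are no longer indices into the largest column, but a map from column log_size to query indices. A minimal standalone sketch (not part of the commit; the concrete indices are invented) of the shape prepare_merkle now builds, and of the invariants decommit asserts on it:

    use std::collections::BTreeMap;

    fn main() {
        // One sorted, deduplicated query list per column log_size
        // (here 6..9, mirroring the test's log_size_range).
        let mut queries: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
        queries.insert(6, vec![3, 17, 50]);
        queries.insert(7, vec![1, 90]);
        queries.insert(8, vec![0, 200, 255]);

        for (log_size, qs) in &queries {
            // decommit() asserts strictly increasing (sorted + deduped) queries.
            assert!(qs.windows(2).all(|w| w[0] < w[1]));
            // Each query must index a node of that layer: 0 <= q < 2^log_size.
            assert!(qs.iter().all(|&q| q < (1usize << log_size)));
        }
        println!("{queries:?}");
    }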
159 changes: 137 additions & 22 deletions src/commitment_scheme/prover.rs
@@ -1,13 +1,16 @@
 use std::cmp::Reverse;
+use std::collections::BTreeMap;
 
 use itertools::Itertools;
 
 use super::ops::{MerkleHasher, MerkleOps};
 use crate::core::backend::{Col, Column};
 use crate::core::fields::m31::BaseField;
+use crate::core::utils::PeekableExt;
 
 pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
     /// Layers of the Merkle tree.
     /// TODO::
     /// The first layer is the largest column.
     /// The last layer is the root.
+    /// See [MerkleOps::commit_on_layer] for more details.
@@ -33,22 +36,24 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
     ///
     /// A new instance of `MerkleProver` with the committed layers.
     pub fn commit(columns: Vec<&Col<B, BaseField>>) -> Self {
-        // Check that columns are of descending order.
         assert!(!columns.is_empty());
-        assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));
 
-        let mut columns = &mut columns.into_iter().peekable();
+        let columns = &mut columns
+            .into_iter()
+            .sorted_by_key(|c| Reverse(c.len()))
+            .peekable();
         let mut layers: Vec<Col<B, H::Hash>> = Vec::new();
 
         let max_log_size = columns.peek().unwrap().len().ilog2();
         for log_size in (0..=max_log_size).rev() {
             // Take columns of the current log_size.
-            let layer_columns = (&mut columns)
-                .take_while(|column| column.len().ilog2() == log_size)
+            let layer_columns = columns
+                .peek_take_while(|column| column.len().ilog2() == log_size)
                 .collect_vec();
 
             layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
         }
+        layers.reverse();
         Self { layers }
     }

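Why replace take_while with peek_take_while? Iterator::take_while consumes the first element that fails the predicate, so when column sizes are mixed, the old per-layer loop drops the first column of the next size. A standalone sketch of the peek-based pattern (my illustration of the idea; the actual PeekableExt::peek_take_while in core::utils may differ):

    fn main() {
        // Column lengths, sorted descending: two of log_size 3, then 2, then 1.
        let mut it = vec![8usize, 8, 4, 2].into_iter().peekable();

        // Take the leading 8s without consuming the 4.
        let mut eights = vec![];
        while it.peek().map_or(false, |&x| x == 8) {
            eights.push(it.next().unwrap());
        }

        assert_eq!(eights, vec![8, 8]);
        // The 4 survives for the next layer's pass; take_while would have eaten it.
        assert_eq!(it.next(), Some(4));
    }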
@@ -57,36 +62,146 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
     ///
     /// # Arguments
     ///
-    /// * `queries` - A vector of query indices to the largest column.
+    /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that
+    ///   log_size.
     ///
     /// # Returns
     ///
     /// A `Decommitment` struct containing the witness.
-    pub fn decommit(&self, mut queries: Vec<usize>) -> MerkleDecommitment<H> {
-        let mut witness = Vec::new();
-        for layer in &self.layers[..self.layers.len() - 1] {
-            let mut queries_iter = queries.into_iter().peekable();
-
-            // Propagate queries and hashes to the next layer.
-            let mut next_queries = Vec::new();
-            while let Some(query) = queries_iter.next() {
-                next_queries.push(query / 2);
-                if queries_iter.next_if_eq(&(query ^ 1)).is_some() {
-                    continue;
+    pub fn decommit(
+        &self,
+        queries_per_log_size: BTreeMap<u32, Vec<usize>>,
+        columns: Vec<&Col<B, BaseField>>,
+    ) -> (Vec<Vec<BaseField>>, MerkleDecommitment<H>) {
+        // Prepare output buffers.
+        let mut queried_values_by_layer = vec![];
+        let mut decommitment = MerkleDecommitment::empty();
+
+        // Sort columns by layer.
+        let mut columns_by_layer = columns
+            .iter()
+            .sorted_by_key(|c| Reverse(c.len()))
+            .peekable();
+
+        // Check that queries are sorted and deduped.
+        for queries in queries_per_log_size.values() {
+            assert!(
+                queries.windows(2).all(|w| w[0] < w[1]),
+                "Queries are not sorted."
+            );
+        }
+
+        let mut last_layer_queries = vec![];
+        for layer_log_size in (0..self.layers.len() as u32).rev() {
+            // Prepare write buffer for queried values to the current layer.
+            let mut layer_queried_values = vec![];
+
+            // Prepare write buffer for queries to the current layer. This will propagate to the
+            // next layer.
+            let mut layer_total_queries = vec![];
+
+            // Each layer node is a hash of column values as well as previous layer hashes.
+            // Prepare the relevant columns and previous layer hashes to read from.
+            let layer_columns = columns_by_layer
+                .peek_take_while(|column| column.len().ilog2() == layer_log_size)
+                .collect_vec();
+            let previous_layer_hashes = self.layers.get(layer_log_size as usize + 1);
+
+            // Queries to this layer come from queried nodes in the previous layer and queried
+            // columns in this one.
+            let mut prev_layer_queries = last_layer_queries.into_iter().peekable();
+            let mut layer_column_queries = queries_per_log_size
+                .get(&layer_log_size)
+                .into_iter()
+                .flatten()
+                .copied()
+                .peekable();
+
+            // Merge previous layer queries and column queries.
+            while let Some(node_index) = prev_layer_queries
+                .peek()
+                .map(|q| *q / 2)
+                .into_iter()
+                .chain(layer_column_queries.peek().into_iter().copied())
+                .min()
+            {
+                if let Some(previous_layer_hashes) = previous_layer_hashes {
+                    // If the left child was not computed, add it to the witness.
+                    if prev_layer_queries.next_if_eq(&(2 * node_index)).is_none() {
+                        decommitment
+                            .hash_witness
+                            .push(previous_layer_hashes.at(2 * node_index));
+                    }
+
+                    // If the right child was not computed, add it to the witness.
+                    if prev_layer_queries
+                        .next_if_eq(&(2 * node_index + 1))
+                        .is_none()
+                    {
+                        decommitment
+                            .hash_witness
+                            .push(previous_layer_hashes.at(2 * node_index + 1));
+                    }
                 }
-                witness.push(layer.at(query ^ 1));
+
+                // If the column values were queried, return them.
+                let node_values = layer_columns.iter().map(|c| c.at(node_index));
+                if layer_column_queries.next_if_eq(&node_index).is_some() {
+                    layer_queried_values.push(node_values.collect_vec());
+                } else {
+                    // Otherwise, add them to the witness.
+                    decommitment.column_witness.extend(node_values);
+                }
+
+                layer_total_queries.push(node_index);
             }
-            queries = next_queries;
+
+            queried_values_by_layer.push(layer_queried_values);
+
+            // Propagate queries to the next layer.
+            last_layer_queries = layer_total_queries;
         }
-        MerkleDecommitment { witness }
+        queried_values_by_layer.reverse();
+
+        // Rearrange returned queried values according to the input order, not by layer.
+        let mut queried_values_by_layer = queried_values_by_layer
+            .into_iter()
+            .map(|layer_results| {
+                layer_results
+                    .into_iter()
+                    .map(|x| x.into_iter())
+                    .collect_vec()
+            })
+            .collect_vec();
+
+        let queried_values = columns
+            .iter()
+            .map(|column| {
+                let a = queried_values_by_layer
+                    .get_mut(column.len().ilog2() as usize)
+                    .unwrap();
+                a.iter_mut().map(|x| x.next().unwrap()).collect_vec()
+            })
+            .collect_vec();
+
+        (queried_values, decommitment)
     }
 
     pub fn root(&self) -> H::Hash {
-        self.layers.last().unwrap().at(0)
+        self.layers.first().unwrap().at(0)
     }
 }
 
 #[derive(Debug)]
 pub struct MerkleDecommitment<H: MerkleHasher> {
-    pub witness: Vec<H::Hash>,
+    pub hash_witness: Vec<H::Hash>,
+    pub column_witness: Vec<BaseField>,
 }
+impl<H: MerkleHasher> MerkleDecommitment<H> {
+    fn empty() -> Self {
+        Self {
+            hash_witness: Vec::new(),
+            column_witness: Vec::new(),
+        }
+    }
+}
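The heart of both the old and new decommit is the same witness rule: a child hash is shipped only if the verifier cannot recompute it from its own queries, and queries propagate upward as parent = query / 2. A standalone sketch of that rule on one layer (not the crate's code; the indices are invented):

    fn main() {
        // Queried node indices in a layer of 8 nodes, sorted and deduped.
        let layer_queries = vec![2usize, 3, 6];

        let mut parent_queries = vec![];
        let mut witness_indices = vec![];
        let mut iter = layer_queries.into_iter().peekable();
        while let Some(q) = iter.next() {
            parent_queries.push(q / 2);
            // Sibling also queried: the verifier recomputes it, no witness needed.
            if iter.next_if_eq(&(q ^ 1)).is_some() {
                continue;
            }
            // Sibling not queried: its hash must be shipped in the witness.
            witness_indices.push(q ^ 1);
        }

        // 2 and 3 are siblings, so only node 6 needs its sibling 7 witnessed.
        assert_eq!(witness_indices, vec![7]);
        assert_eq!(parent_queries, vec![1, 3]);
    }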