Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple mixed merkle tree #525

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions src/commitment_scheme/blake2_merkle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use itertools::Itertools;
use num_traits::Zero;

use super::blake2s_ref::compress;
use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

/// Merkle hasher built on the Blake2s compression function.
/// A hash is the raw 8-word (256-bit) Blake2s state.
pub struct Blake2Hasher;
impl MerkleHasher for Blake2Hasher {
    type Hash = [u32; 8];

    /// Hashes a single Merkle node: first folds the two child hashes (if any)
    /// into the state, then folds the column values in 16-word blocks,
    /// zero-padding the final block.
    fn hash_node(
        children_hashes: Option<(Self::Hash, Self::Hash)>,
        column_values: &[BaseField],
    ) -> Self::Hash {
        let mut state = [0; 8];
        if let Some((left, right)) = children_hashes {
            state = compress(
                state,
                // SAFETY: [[u32; 8]; 2] and [u32; 16] have identical size and layout.
                unsafe { std::mem::transmute([left, right]) },
                // NOTE(review): the four trailing zeros are presumably the counter and
                // finalization words of the reference compression — confirm against
                // `blake2s_ref::compress`'s signature.
                0,
                0,
                0,
                0,
            );
        }
        // Number of zero elements needed to pad the column values up to a
        // multiple of 16, i.e. (16 - len % 16) % 16.
        let rem = 15 - ((column_values.len() + 15) % 16);
        let padded_values = column_values
            .iter()
            .copied()
            .chain(std::iter::repeat(BaseField::zero()).take(rem));
        for chunk in padded_values.array_chunks::<16>() {
            // SAFETY: assumes BaseField is a transparent wrapper around a u32, so
            // [BaseField; 16] transmutes to [u32; 16] — TODO confirm the field's repr.
            state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
        }
        state
    }
}

impl MerkleOps<Blake2Hasher> for CPUBackend {
    /// Hashes every node of one Merkle layer on the CPU, node by node.
    /// Node `i` takes `prev_layer[2i]` and `prev_layer[2i + 1]` as children
    /// (when a previous layer exists) and entry `i` of each extra column.
    fn commit_on_layer(
        log_size: u32,
        prev_layer: Option<&Vec<[u32; 8]>>,
        columns: &[&Vec<BaseField>],
    ) -> Vec<[u32; 8]> {
        let n_nodes = 1 << log_size;
        let mut layer = Vec::with_capacity(n_nodes);
        for i in 0..n_nodes {
            let children = prev_layer.map(|hashes| (hashes[2 * i], hashes[2 * i + 1]));
            let node_values: Vec<BaseField> = columns.iter().map(|col| col[i]).collect();
            layer.push(Blake2Hasher::hash_node(children, &node_values));
        }
        layer
    }
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use num_traits::Zero;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
    use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
    use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
    use crate::core::backend::CPUBackend;
    use crate::core::fields::m31::BaseField;

    // (queries, decommitment, queried column values as (log_size, values), verifier).
    type TestData = (
        Vec<usize>,
        MerkleDecommitment<Blake2Hasher>,
        Vec<(u32, Vec<BaseField>)>,
        MerkleTreeVerifier<Blake2Hasher>,
    );
    /// Builds a mixed-size Merkle tree from seeded-random columns, commits to it,
    /// draws random queries, and returns everything a verification test needs.
    /// The fixed seed makes every test see the same tree and queries.
    fn prepare_merkle() -> TestData {
        const N_COLS: usize = 400;
        const N_QUERIES: usize = 7;

        let rng = &mut StdRng::seed_from_u64(0);
        // Random column log-sizes in [6, 8], sorted descending as `commit` requires.
        let log_sizes = (0..N_COLS)
            .map(|_| rng.gen_range(6..9))
            .sorted()
            .rev()
            .collect_vec();
        let max_log_size = *log_sizes.iter().max().unwrap();
        // Random column contents.
        let cols = log_sizes
            .iter()
            .map(|&log_size| {
                (0..(1 << log_size))
                    .map(|_| BaseField::from(rng.gen_range(0..(1 << 30))))
                    .collect_vec()
            })
            .collect_vec();
        let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());

        // Sorted, deduplicated query indices into the largest layer.
        let queries = (0..N_QUERIES)
            .map(|_| rng.gen_range(0..(1 << max_log_size)))
            .sorted()
            .dedup()
            .collect_vec();
        let decommitment = merkle.decommit(queries.clone());
        // For each column, the queried values: queries are given on the largest
        // layer, so shift each one down to the column's own layer and dedup.
        let values = cols
            .iter()
            .map(|col| {
                let layer_queries = queries
                    .iter()
                    .map(|&q| q >> (max_log_size - col.len().ilog2()))
                    .dedup();
                layer_queries.map(|q| col[q]).collect_vec()
            })
            .collect_vec();
        let values = log_sizes.into_iter().zip(values).collect_vec();

        let verifier = MerkleTreeVerifier {
            root: merkle.root(),
        };
        (queries, decommitment, values, verifier)
    }

    // An untampered decommitment verifies successfully.
    #[test]
    fn test_merkle_success() {
        let (queries, decommitment, values, verifier) = prepare_merkle();

        verifier.verify(queries, values, decommitment).unwrap();
    }

    // Corrupting one witness hash must surface as a root mismatch.
    #[test]
    fn test_merkle_invalid_witness() {
        let (queries, mut decommitment, values, verifier) = prepare_merkle();
        decommitment.witness[20] = [0; 8];

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::RootMismatch
        );
    }

    // Corrupting one queried column value must surface as a root mismatch.
    #[test]
    fn test_merkle_invalid_value() {
        let (queries, decommitment, mut values, verifier) = prepare_merkle();
        values[3].1[6] = BaseField::zero();

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::RootMismatch
        );
    }

    // Dropping the last witness hash must be reported as a too-short witness.
    #[test]
    fn test_merkle_witness_too_short() {
        let (queries, mut decommitment, values, verifier) = prepare_merkle();
        decommitment.witness.pop();

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::WitnessTooShort
        );
    }

    // An extra value in one column must be reported as too-long column values.
    #[test]
    fn test_merkle_column_values_too_long() {
        let (queries, decommitment, mut values, verifier) = prepare_merkle();
        values[3].1.push(BaseField::zero());

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::ColumnValuesTooLong
        );
    }

    // A missing value in one column must be reported as too-short column values.
    #[test]
    fn test_merkle_column_values_too_short() {
        let (queries, decommitment, mut values, verifier) = prepare_merkle();
        values[3].1.pop();

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::ColumnValuesTooShort
        );
    }

    // An extra witness hash must be reported as a too-long witness.
    #[test]
    fn test_merkle_witness_too_long() {
        let (queries, mut decommitment, values, verifier) = prepare_merkle();
        decommitment.witness.push([0; 8]);

        assert_eq!(
            verifier.verify(queries, values, decommitment).unwrap_err(),
            MerkleVerificationError::WitnessTooLong
        );
    }
}
4 changes: 4 additions & 0 deletions src/commitment_scheme/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod blake2_hash;
pub mod blake2_merkle;
pub mod blake2s_ref;
pub mod blake3_hash;
pub mod hasher;
Expand All @@ -8,4 +9,7 @@ pub mod merkle_multilayer;
pub mod merkle_tree;
pub mod mixed_degree_decommitment;
pub mod mixed_degree_merkle_tree;
pub mod ops;
pub mod prover;
pub mod utils;
pub mod verifier;
40 changes: 40 additions & 0 deletions src/commitment_scheme/ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use crate::core::backend::{Col, ColumnOps};
use crate::core::fields::m31::BaseField;

/// A Merkle node hash is a hash of:
///   [left_child_hash, right_child_hash], column0_value, column1_value, ...
/// "[]" denotes optional values.
/// The largest Merkle layer has no left and right child hashes. The rest of the layers have
/// children hashes.
/// At each layer, the tree may have multiple columns of the same length as the layer.
/// Each node in that layer contains one value from each column.
pub trait MerkleHasher {
    /// The digest type. `Clone + Eq + Debug` lets hashes be stored in layers,
    /// compared during verification, and printed in error output.
    type Hash: Clone + Eq + std::fmt::Debug;
    /// Hashes a single Merkle node. See [MerkleHasher] for more details.
    /// `children_hashes` is `None` only for the largest (leaf) layer.
    fn hash_node(
        children_hashes: Option<(Self::Hash, Self::Hash)>,
        column_values: &[BaseField],
    ) -> Self::Hash;
}

/// Trait for performing Merkle operations on a commitment scheme.
/// The `ColumnOps` supertraits supply the backend's column types for both
/// base field values and hash digests.
pub trait MerkleOps<H: MerkleHasher>: ColumnOps<BaseField> + ColumnOps<H::Hash> {
    /// Commits on an entire layer of the Merkle tree.
    /// See [MerkleHasher] for more details.
    ///
    /// The layer has 2^`log_size` nodes that need to be hashed. The topmost layer has 1 node,
    /// which is a hash of 2 children and some columns.
    ///
    /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer.
    /// That layer is assumed to have 2^(`log_size`+1) nodes.
    ///
    /// `columns` are the extra columns that need to be hashed in each node.
    /// They are assumed to be of size 2^`log_size`.
    ///
    /// Returns the next Merkle layer hashes.
    fn commit_on_layer(
        log_size: u32,
        prev_layer: Option<&Col<Self, H::Hash>>,
        columns: &[&Col<Self, BaseField>],
    ) -> Col<Self, H::Hash>;
}
92 changes: 92 additions & 0 deletions src/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use std::cmp::Reverse;

use itertools::Itertools;

use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;

/// A Merkle-tree prover holding all committed layers of hashes.
pub struct MerkleProver<B: MerkleOps<H>, H: MerkleHasher> {
    /// Layers of the Merkle tree.
    /// The first layer is the largest column.
    /// The last layer is the root.
    /// See [MerkleOps::commit_on_layer] for more details.
    pub layers: Vec<Col<B, H::Hash>>,
}
/// The MerkleProver struct represents a prover for a Merkle commitment scheme.
/// It is generic over the types `B` and `H`, which represent the Merkle operations and Merkle
/// hasher respectively.
impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
    /// Commits to columns.
    /// Columns must be of power of 2 sizes and sorted in descending order.
    ///
    /// # Arguments
    ///
    /// * `columns` - A vector of references to columns.
    ///
    /// # Panics
    ///
    /// This function will panic if the columns are not sorted in descending order or if the columns
    /// vector is empty.
    ///
    /// # Returns
    ///
    /// A new instance of `MerkleProver` with the committed layers.
    pub fn commit(columns: Vec<&Col<B, BaseField>>) -> Self {
        // Check that columns are of descending order.
        assert!(!columns.is_empty());
        assert!(columns.is_sorted_by_key(|c| Reverse(c.len())));

        // NOTE(review): `&mut` on the right-hand side is redundant — binding the
        // peekable by value and re-borrowing below would read more cleanly.
        let mut columns = &mut columns.into_iter().peekable();
        let mut layers: Vec<Col<B, H::Hash>> = Vec::new();

        let max_log_size = columns.peek().unwrap().len().ilog2();
        // Build layers bottom-up: from the 2^max_log_size-node layer to the root.
        for log_size in (0..=max_log_size).rev() {
            // Take columns of the current log_size.
            // NOTE(review): `take_while` also consumes-and-discards the first column
            // whose size does NOT match, so the first column of each smaller log_size
            // appears to never be hashed into the commitment. This looks unintended —
            // a peeking take-while (e.g. `Peekable::next_if` or itertools'
            // `peeking_take_while`) would avoid it. Confirm against verifier.rs,
            // which may mirror this behavior.
            let layer_columns = (&mut columns)
                .take_while(|column| column.len().ilog2() == log_size)
                .collect_vec();

            layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns));
        }
        Self { layers }
    }

    /// Decommits to columns on the given queries.
    /// Queries are given as indices to the largest column.
    ///
    /// # Arguments
    ///
    /// * `queries` - A vector of query indices to the largest column.
    ///
    /// # Returns
    ///
    /// A `Decommitment` struct containing the witness.
    pub fn decommit(&self, mut queries: Vec<usize>) -> MerkleDecommitment<H> {
        let mut witness = Vec::new();
        // Walk every layer except the root, collecting sibling hashes.
        for layer in &self.layers[..self.layers.len() - 1] {
            let mut queries_iter = queries.into_iter().peekable();

            // Propagate queries and hashes to the next layer.
            let mut next_queries = Vec::new();
            while let Some(query) = queries_iter.next() {
                // The parent index on the next layer.
                next_queries.push(query / 2);
                // If the sibling (query ^ 1) is also queried, the verifier can
                // compute both children itself — no witness hash is needed.
                if queries_iter.next_if_eq(&(query ^ 1)).is_some() {
                    continue;
                }
                // Otherwise supply the sibling's hash as witness.
                witness.push(layer.at(query ^ 1));
            }
            queries = next_queries;
        }
        MerkleDecommitment { witness }
    }

    /// Returns the root hash: the single node of the topmost layer.
    pub fn root(&self) -> H::Hash {
        self.layers.last().unwrap().at(0)
    }
}

/// The witness part of a Merkle decommitment: sibling hashes the verifier
/// cannot derive itself, in the order `MerkleProver::decommit` emitted them.
#[derive(Debug)]
pub struct MerkleDecommitment<H: MerkleHasher> {
    // Sibling hashes, ordered layer by layer, then by query position.
    pub witness: Vec<H::Hash>,
}
Loading
Loading