Commit 8cc9919: Use simple merkle tree
spapinistarkware committed Mar 25, 2024 (1 parent: c1ab684)
Showing 6 changed files with 63 additions and 94 deletions.
29 changes: 14 additions & 15 deletions src/commitment_scheme/blake2_merkle.rs
@@ -1,15 +1,13 @@
use itertools::Itertools;
use num_traits::Zero;

use super::blake2_hash::{Blake2sHash, Blake2sHasher};
use super::blake2s_ref::compress;
use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

pub struct Blake2Hasher;
impl MerkleHasher for Blake2Hasher {
type Hash = [u32; 8];

impl MerkleHasher for Blake2sHasher {
fn hash_node(
children_hashes: Option<(Self::Hash, Self::Hash)>,
column_values: &[BaseField],
@@ -33,19 +31,19 @@ impl MerkleHasher for Blake2Hasher {
for chunk in padded_values.array_chunks::<16>() {
state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
}
state
unsafe { std::mem::transmute(state) }
}
}

impl MerkleOps<Blake2Hasher> for CPUBackend {
impl MerkleOps<Blake2sHasher> for CPUBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<[u32; 8]>>,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Vec<BaseField>],
) -> Vec<[u32; 8]> {
) -> Vec<Blake2sHash> {
(0..(1 << log_size))
.map(|i| {
Blake2Hasher::hash_node(
Blake2sHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
@@ -61,17 +59,18 @@ mod tests {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sHasher;
use crate::commitment_scheme::prover::{Decommitment, MerkleProver};
use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

type TestData = (
Vec<usize>,
Decommitment<Blake2Hasher>,
Decommitment<Blake2sHasher>,
Vec<(u32, Vec<BaseField>)>,
MerkleTreeVerifier<Blake2Hasher>,
MerkleTreeVerifier<Blake2sHasher>,
);
fn prepare_merkle() -> TestData {
const N_COLS: usize = 400;
@@ -92,7 +91,7 @@ mod tests {
.collect_vec()
})
.collect_vec();
let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());
let merkle = MerkleProver::<CPUBackend, Blake2sHasher>::commit(cols.iter().collect_vec());

let queries = (0..N_QUERIES)
.map(|_| rng.gen_range(0..(1 << max_log_size)))
@@ -128,7 +127,7 @@ mod tests {
#[test]
fn test_merkle_invalid_witness() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.witness[20] = [0; 8];
decommitment.witness[20] = Blake2sHash::from(&[0u8; 32][..]);

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
@@ -183,7 +182,7 @@ mod tests {
#[test]
fn test_merkle_witness_too_long() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.witness.push([0; 8]);
decommitment.witness.push(Blake2sHash::from(&[0u8; 32][..]));

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
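For orientation, a minimal sketch (not part of this commit) of what the `unsafe { std::mem::transmute(state) }` above accomplishes: it reinterprets the `[u32; 8]` compression state as the crate's 32-byte Blake2sHash. Assuming the little-endian word layout of the blake2s reference code, a safe equivalent would look like this:

// Sketch only: serialize the [u32; 8] compression state into 32 bytes,
// assuming little-endian words as in the blake2s reference implementation.
fn state_to_bytes(state: [u32; 8]) -> [u8; 32] {
    let mut bytes = [0u8; 32];
    for (i, word) in state.iter().enumerate() {
        bytes[4 * i..4 * i + 4].copy_from_slice(&word.to_le_bytes());
    }
    bytes
}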
4 changes: 2 additions & 2 deletions src/commitment_scheme/ops.rs
@@ -1,8 +1,8 @@
use super::hasher::Hasher;
use crate::core::backend::{Col, ColumnOps};
use crate::core::fields::m31::BaseField;

pub trait MerkleHasher {
type Hash: Clone + Eq + std::fmt::Debug;
pub trait MerkleHasher: Hasher {
/// Hashes a single Merkle node.
/// The node may or may not need to hash 2 hashes from the previous layer - depending if it is a
/// leaf or not.
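The hash_node contract documented above is: a leaf sees only the column values injected at its layer, while an internal node additionally folds in its two children from the previous layer. A toy, non-cryptographic stand-in (illustrative only, not the crate's trait):

// Toy stand-in for the hash_node contract; the real trait uses a proper hash.
fn toy_hash_node(children: Option<(u64, u64)>, column_values: &[u32]) -> u64 {
    // Internal nodes start from their two children; leaves start from zero.
    let mut acc: u64 = match children {
        Some((left, right)) => left.wrapping_mul(31).wrapping_add(right),
        None => 0,
    };
    // Both kinds of node then absorb the column values injected at this layer.
    for v in column_values {
        acc = acc.wrapping_mul(1_000_003).wrapping_add(u64::from(*v));
    }
    acc
}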
1 change: 1 addition & 0 deletions src/commitment_scheme/prover.rs
@@ -60,6 +60,7 @@ impl<B: MerkleOps<H>, H: MerkleHasher> MerkleProver<B, H> {
}
}

#[derive(Debug)]
pub struct Decommitment<H: MerkleHasher> {
pub witness: Vec<H::Hash>,
}
10 changes: 5 additions & 5 deletions src/commitment_scheme/verifier.rs
@@ -4,10 +4,12 @@ use std::iter::Peekable;
use itertools::Itertools;
use thiserror::Error;

use super::hasher::Hasher;
use super::ops::MerkleHasher;
use super::prover::Decommitment;
use crate::core::fields::m31::BaseField;

// TODO(spapini): This struct is not necessary. Make it a function on decommitment?
pub struct MerkleTreeVerifier<H: MerkleHasher> {
pub root: H::Hash,
}
@@ -52,7 +54,7 @@ impl<H: MerkleHasher> MerkleTreeVerifier<H> {
}

struct MerkleVerifier<H: MerkleHasher> {
witness: std::vec::IntoIter<<H as MerkleHasher>::Hash>,
witness: std::vec::IntoIter<<H as Hasher>::Hash>,
column_values: Peekable<std::vec::IntoIter<(u32, Vec<BaseField>)>>,
layer_column_values: Vec<std::vec::IntoIter<BaseField>>,
}
@@ -62,6 +64,7 @@ impl<H: MerkleHasher> MerkleVerifier<H> {
queries: Vec<usize>,
) -> Result<H::Hash, MerkleVerificationError> {
let max_log_size = self.column_values.peek().unwrap().0;
assert!(*queries.iter().max().unwrap() < 1 << max_log_size);

// A sequence of queries to the current layer.
// Each query is a pair of the query index and the known hashes of the children, if any.
@@ -144,10 +147,7 @@ impl<H: MerkleHasher> MerkleVerifier<H> {
}
}

type ChildrenHashesAtQuery<H> = Option<(
Option<<H as MerkleHasher>::Hash>,
Option<<H as MerkleHasher>::Hash>,
)>;
type ChildrenHashesAtQuery<H> = Option<(Option<<H as Hasher>::Hash>, Option<<H as Hasher>::Hash>)>;

#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
pub enum MerkleVerificationError {
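The verifier's layer walk relies on the same indexing used in commit_on_layer above: the children of node i sit at 2i and 2i+1 one layer down. A small sketch (not the crate's code) of how sorted query positions propagate up a layer:

// Sketch: move sorted query positions one layer up the tree.
fn next_layer_queries(queries: &[usize]) -> Vec<usize> {
    let mut parents: Vec<usize> = queries.iter().map(|q| q / 2).collect();
    parents.dedup(); // input is sorted, so duplicate parents are adjacent
    parents
}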
68 changes: 13 additions & 55 deletions src/core/commitment_scheme/prover.rs
@@ -1,5 +1,5 @@
use std::cmp::Reverse;
use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;

@@ -20,9 +20,7 @@ use super::super::prover::{
use super::super::ColumnVec;
use super::utils::TreeVec;
use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::merkle_input::{MerkleTreeColumnLayout, MerkleTreeInput};
use crate::commitment_scheme::mixed_degree_decommitment::MixedDecommitment;
use crate::commitment_scheme::mixed_degree_merkle_tree::MixedDegreeMerkleTree;
use crate::commitment_scheme::prover::{Decommitment, MerkleProver};
use crate::core::channel::Channel;

type MerkleHasher = Blake2sHasher;
@@ -48,7 +46,7 @@ impl CommitmentSchemeProver {
}

pub fn roots(&self) -> TreeVec<Blake2sHash> {
self.trees.as_ref().map(|tree| tree.root())
self.trees.as_ref().map(|tree| tree.commitment.root())
}

pub fn polynomials(&self) -> TreeVec<ColumnVec<&CPUCirclePoly>> {
@@ -132,7 +130,7 @@ impl CommitmentSchemeProver {
#[derive(Debug)]
pub struct CommitmentSchemeProof {
pub proved_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MixedDecommitment<BaseField, MerkleHasher>>,
pub decommitments: TreeVec<Decommitment<Blake2sHasher>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub proof_of_work: ProofOfWorkProof,
pub fri_proof: FriProof<MerkleHasher>,
@@ -143,8 +141,7 @@ pub struct CommitmentSchemeProof {
pub struct CommitmentTreeProver {
pub polynomials: ColumnVec<CPUCirclePoly>,
pub evaluations: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
pub commitment: MixedDegreeMerkleTree<BaseField, Blake2sHasher>,
column_layout: MerkleTreeColumnLayout,
pub commitment: MerkleProver<CPUBackend, Blake2sHasher>,
}

impl CommitmentTreeProver {
@@ -155,39 +152,23 @@ impl CommitmentTreeProver {
) -> Self {
let evaluations = polynomials
.iter()
.sorted_by_key(|eval| Reverse(eval.log_size()))
.map(|poly| {
poly.evaluate(
CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(),
)
})
.collect_vec();

let mut merkle_input = MerkleTreeInput::new();
const LOG_N_BASEFIELD_ELEMENTS_IN_SACK: u32 = 4;

// The desired depth for column of log_length n is such that Blake2s hashes are filled(64B).
// Explicitly: There are 2^(d-1) hash 'sacks' at depth d, hence, with elements of 4 bytes,
// 2^(d-1) = 2^n / 16, => d = n-3.
// Assuming rectangle trace, all columns go to the same depth.
// TOOD(AlonH): remove this assumption.
let inject_depth = std::cmp::max::<i32>(
evaluations[0].len().ilog2() as i32 - (LOG_N_BASEFIELD_ELEMENTS_IN_SACK as i32 - 1),
1,
let tree = MerkleProver::<CPUBackend, Blake2sHasher>::commit(
evaluations.iter().map(|e| &e.values).collect_vec(),
);
for column in evaluations.iter().map(|eval| &eval.values) {
merkle_input.insert_column(inject_depth as usize, column);
}
let (tree, root) =
MixedDegreeMerkleTree::<BaseField, Blake2sHasher>::commit_default(&merkle_input);
channel.mix_digest(root);

let column_layout = merkle_input.column_layout();
channel.mix_digest(tree.root());

CommitmentTreeProver {
polynomials,
evaluations,
commitment: tree,
column_layout,
}
}

@@ -196,39 +177,16 @@ impl CommitmentTreeProver {
fn decommit(
&self,
queries: Vec<usize>,
) -> (
ColumnVec<Vec<BaseField>>,
MixedDecommitment<BaseField, Blake2sHasher>,
) {
) -> (ColumnVec<Vec<BaseField>>, Decommitment<Blake2sHasher>) {
// TODO(spapini): Queries should be the queries to the largest layer.
// When we have more than one component, we should extract the values correctly
let values = self
.evaluations
.iter()
.map(|c| queries.iter().map(|p| c[*p]).collect())
.collect();
// Assuming rectangle trace, queries should be similar for all columns.
// TOOD(AlonH): remove this assumption.
let queries = std::iter::repeat(queries.to_vec())
.take(self.evaluations.len())
.collect_vec();

// Rebuild the merkle input for now.
// TODO(Ohad): change after tree refactor. Consider removing the input struct and have the
// decommitment take queries and columns only.
let eval_vec = self
.evaluations
.iter()
.map(|eval| &eval.values[..])
.collect_vec();
let input = self.column_layout.build_input(&eval_vec);
let decommitment = self.commitment.decommit(&input, &queries);
let decommitment = self.commitment.decommit(queries);
(values, decommitment)
}
}

impl Deref for CommitmentTreeProver {
type Target = MixedDegreeMerkleTree<BaseField, Blake2sHasher>;

fn deref(&self) -> &Self::Target {
&self.commitment
}
}
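The added .sorted_by_key(|eval| Reverse(eval.log_size())) presumably feeds columns to MerkleProver from largest to smallest, matching the descending order the verifier reconstructs below. A self-contained sketch of that ordering, with made-up log sizes:

use std::cmp::Reverse;

// Toy data: sort column log-sizes from largest to smallest, mirroring the
// sort added to CommitmentTreeProver::new in this commit.
fn main() {
    let mut log_sizes = vec![3u32, 7, 5, 7];
    log_sizes.sort_by_key(|s| Reverse(*s));
    assert_eq!(log_sizes, vec![7, 7, 5, 3]);
}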
45 changes: 28 additions & 17 deletions src/core/commitment_scheme/verifier.rs
@@ -1,3 +1,4 @@
use std::cmp::Reverse;
use std::iter::zip;

use itertools::Itertools;
@@ -17,7 +18,8 @@ use super::super::queries::SparseSubCircleDomain;
use super::utils::TreeVec;
use super::CommitmentSchemeProof;
use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::mixed_degree_decommitment::MixedDecommitment;
use crate::commitment_scheme::prover::Decommitment;
use crate::commitment_scheme::verifier::MerkleTreeVerifier;
use crate::core::channel::Channel;
use crate::core::prover::VerificationError;
use crate::core::ColumnVec;
@@ -83,20 +85,19 @@ impl CommitmentSchemeVerifier {
if !self
.trees
.as_ref()
.zip(&proof.decommitments)
.map(|(tree, decommitment)| {
.zip(proof.decommitments)
.zip(proof.queried_values.clone())
.map(|((tree, decommitment), queried_values)| {
// TODO(spapini): Also verify proved_values here.
// Assuming columns are of equal lengths, replicate queries for all columns.
// TOOD(AlonH): remove this assumption.
tree.verify(
decommitment,
&std::iter::repeat(
fri_query_domains[&(tree.log_sizes[0] + LOG_BLOWUP_FACTOR)]
.flatten()
.clone(),
)
.take(tree.log_sizes.len())
.collect_vec(),
queried_values,
// Queries to the largest size.
fri_query_domains[&(tree.log_sizes[0] + LOG_BLOWUP_FACTOR)]
.flatten()
.clone(),
)
})
.iter()
@@ -186,13 +187,23 @@ impl CommitmentTreeVerifier {

pub fn verify(
&self,
decommitment: &MixedDecommitment<BaseField, Blake2sHasher>,
queries: &[Vec<usize>],
decommitment: Decommitment<Blake2sHasher>,
values: Vec<Vec<BaseField>>,
queries: Vec<usize>,
) -> bool {
decommitment.verify(
self.commitment,
queries,
decommitment.queried_values.iter().copied(),
)
let values = self
.log_sizes
.iter()
.map(|log_size| *log_size + LOG_BLOWUP_FACTOR)
.zip(values)
.sorted_by_key(|(log_size, _)| Reverse(*log_size))
.collect_vec();

// TODO(spapini): Propagate error.
MerkleTreeVerifier {
root: self.commitment,
}
.verify(queries, values, decommitment)
.is_ok()
}
}
