Skip to content

Commit

Permalink
Use simple merkle tree
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Mar 25, 2024
1 parent 43bb6bf commit 67cf290
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 98 deletions.
6 changes: 6 additions & 0 deletions src/commitment_scheme/blake2_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ impl From<&[u8]> for Blake2sHash {
}
}

impl From<[u8; 32]> for Blake2sHash {
fn from(value: [u8; 32]) -> Self {
Self(value)
}
}

impl AsRef<[u8]> for Blake2sHash {
fn as_ref(&self) -> &[u8] {
&self.0
Expand Down
29 changes: 14 additions & 15 deletions src/commitment_scheme/blake2_merkle.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use itertools::Itertools;
use num_traits::Zero;

use super::blake2_hash::{Blake2sHash, Blake2sHasher};
use super::blake2s_ref::compress;
use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

pub struct Blake2Hasher;
impl MerkleHasher for Blake2Hasher {
type Hash = [u32; 8];

impl MerkleHasher for Blake2sHasher {
fn hash_node(
children_hashes: Option<(Self::Hash, Self::Hash)>,
column_values: &[BaseField],
Expand All @@ -33,19 +31,19 @@ impl MerkleHasher for Blake2Hasher {
for chunk in padded_values.array_chunks::<16>() {
state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0);
}
state
unsafe { std::mem::transmute(state) }
}
}

impl MerkleOps<Blake2Hasher> for CPUBackend {
impl MerkleOps<Blake2sHasher> for CPUBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<[u32; 8]>>,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Vec<BaseField>],
) -> Vec<[u32; 8]> {
) -> Vec<Blake2sHash> {
(0..(1 << log_size))
.map(|i| {
Blake2Hasher::hash_node(
Blake2sHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
Expand All @@ -61,17 +59,18 @@ mod tests {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::commitment_scheme::blake2_merkle::Blake2Hasher;
use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sHasher;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::commitment_scheme::verifier::{MerkleTreeVerifier, MerkleVerificationError};
use crate::core::backend::CPUBackend;
use crate::core::fields::m31::BaseField;

type TestData = (
Vec<usize>,
MerkleDecommitment<Blake2Hasher>,
MerkleDecommitment<Blake2sHasher>,
Vec<(u32, Vec<BaseField>)>,
MerkleTreeVerifier<Blake2Hasher>,
MerkleTreeVerifier<Blake2sHasher>,
);
fn prepare_merkle() -> TestData {
const N_COLS: usize = 400;
Expand All @@ -92,7 +91,7 @@ mod tests {
.collect_vec()
})
.collect_vec();
let merkle = MerkleProver::<CPUBackend, Blake2Hasher>::commit(cols.iter().collect_vec());
let merkle = MerkleProver::<CPUBackend, Blake2sHasher>::commit(cols.iter().collect_vec());

let queries = (0..N_QUERIES)
.map(|_| rng.gen_range(0..(1 << max_log_size)))
Expand Down Expand Up @@ -128,7 +127,7 @@ mod tests {
#[test]
fn test_merkle_invalid_witness() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.witness[20] = [0; 8];
decommitment.witness[20] = Blake2sHash::from([0u8; 32]);

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
Expand Down Expand Up @@ -183,7 +182,7 @@ mod tests {
#[test]
fn test_merkle_witness_too_long() {
let (queries, mut decommitment, values, verifier) = prepare_merkle();
decommitment.witness.push([0; 8]);
decommitment.witness.push(Blake2sHash::from(&[0u8; 32][..]));

assert_eq!(
verifier.verify(queries, values, decommitment).unwrap_err(),
Expand Down
4 changes: 2 additions & 2 deletions src/commitment_scheme/ops.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::hasher::Hasher;
use crate::core::backend::{Col, ColumnOps};
use crate::core::fields::m31::BaseField;

Expand All @@ -8,8 +9,7 @@ use crate::core::fields::m31::BaseField;
/// children hashes.
/// At each layer, the tree may have multiple columns of the same length as the layer.
/// Each node in that layer contains one value from each column.
pub trait MerkleHasher {
type Hash: Clone + Eq + std::fmt::Debug;
pub trait MerkleHasher: Hasher {
/// Hashes a single Merkle node. See [MerkleHasher] for more details.
fn hash_node(
children_hashes: Option<(Self::Hash, Self::Hash)>,
Expand Down
8 changes: 3 additions & 5 deletions src/commitment_scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::iter::Peekable;
use itertools::Itertools;
use thiserror::Error;

use super::hasher::Hasher;
use super::ops::MerkleHasher;
use super::prover::MerkleDecommitment;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -79,7 +80,7 @@ impl<H: MerkleHasher> MerkleTreeVerifier<H> {
/// A helper struct for verifying a [MerkleDecommitment].
struct MerkleVerifier<H: MerkleHasher> {
/// A queue for consuming the next hash witness from the decommitment.
witness: std::vec::IntoIter<<H as MerkleHasher>::Hash>,
witness: std::vec::IntoIter<<H as Hasher>::Hash>,
/// A queue for consuming the next claimed values for each column.
column_values: Peekable<std::vec::IntoIter<(u32, Vec<BaseField>)>>,
/// A queue for consuming the next claimed values for each column in the current layer.
Expand Down Expand Up @@ -190,10 +191,7 @@ impl<H: MerkleHasher> MerkleVerifier<H> {
}
}

type ChildrenHashesAtQuery<H> = Option<(
Option<<H as MerkleHasher>::Hash>,
Option<<H as MerkleHasher>::Hash>,
)>;
type ChildrenHashesAtQuery<H> = Option<(Option<<H as Hasher>::Hash>, Option<<H as Hasher>::Hash>)>;

#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
pub enum MerkleVerificationError {
Expand Down
85 changes: 25 additions & 60 deletions src/core/commitment_scheme/prover.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cmp::Reverse;
use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;

Expand All @@ -20,9 +20,7 @@ use super::super::prover::{
use super::super::ColumnVec;
use super::utils::TreeVec;
use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::merkle_input::{MerkleTreeColumnLayout, MerkleTreeInput};
use crate::commitment_scheme::mixed_degree_decommitment::MixedDecommitment;
use crate::commitment_scheme::mixed_degree_merkle_tree::MixedDegreeMerkleTree;
use crate::commitment_scheme::prover::{MerkleDecommitment, MerkleProver};
use crate::core::channel::Channel;

type MerkleHasher = Blake2sHasher;
Expand All @@ -48,7 +46,7 @@ impl CommitmentSchemeProver {
}

pub fn roots(&self) -> TreeVec<Blake2sHash> {
self.trees.as_ref().map(|tree| tree.root())
self.trees.as_ref().map(|tree| tree.commitment.root())
}

pub fn polynomials(&self) -> TreeVec<ColumnVec<&CPUCirclePoly>> {
Expand Down Expand Up @@ -111,14 +109,13 @@ impl CommitmentSchemeProver {

// Decommit the FRI queries on the merkle trees.
let decommitment_results = self.trees.as_ref().map(|tree| {
let queries = tree
.polynomials
let max_log_size = tree
.evaluations
.iter()
.map(|poly| {
fri_query_domains[&(poly.log_size() + self.log_blowup_factor)].flatten()
})
.collect();
tree.decommit(queries)
.map(|e| e.domain.log_size())
.max()
.unwrap();
tree.decommit(fri_query_domains[&max_log_size].flatten())
});

let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone());
Expand All @@ -137,7 +134,7 @@ impl CommitmentSchemeProver {
#[derive(Debug)]
pub struct CommitmentSchemeProof {
pub proved_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MixedDecommitment<BaseField, MerkleHasher>>,
pub decommitments: TreeVec<MerkleDecommitment<Blake2sHasher>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub proof_of_work: ProofOfWorkProof,
pub fri_proof: FriProof<MerkleHasher>,
Expand All @@ -148,8 +145,7 @@ pub struct CommitmentSchemeProof {
pub struct CommitmentTreeProver {
pub polynomials: ColumnVec<CPUCirclePoly>,
pub evaluations: ColumnVec<CPUCircleEvaluation<BaseField, BitReversedOrder>>,
pub commitment: MixedDegreeMerkleTree<BaseField, Blake2sHasher>,
column_layout: MerkleTreeColumnLayout,
pub commitment: MerkleProver<CPUBackend, Blake2sHasher>,
}

impl CommitmentTreeProver {
Expand All @@ -160,71 +156,40 @@ impl CommitmentTreeProver {
) -> Self {
let evaluations = polynomials
.iter()
.sorted_by_key(|poly| Reverse(poly.log_size()))
.map(|poly| {
poly.evaluate(
CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(),
)
})
.collect_vec();

let mut merkle_input = MerkleTreeInput::new();
const LOG_N_BASE_FIELD_ELEMENTS_IN_SACK: u32 = 4;

for eval in evaluations.iter() {
// The desired depth for a column of log_length n is such that Blake2s hashes are
// filled(64B). Explicitly: There are 2^(d-1) hash 'sacks' at depth d,
// hence, with elements of 4 bytes, 2^(d-1) = 2^n / 16, => d = n-3.
let inject_depth = std::cmp::max::<i32>(
eval.domain.log_size() as i32 - (LOG_N_BASE_FIELD_ELEMENTS_IN_SACK as i32 - 1),
1,
);
merkle_input.insert_column(inject_depth as usize, &eval.values);
}

let (tree, root) =
MixedDegreeMerkleTree::<BaseField, Blake2sHasher>::commit_default(&merkle_input);
channel.mix_digest(root);

let column_layout = merkle_input.column_layout();
let tree = MerkleProver::<CPUBackend, Blake2sHasher>::commit(
evaluations.iter().map(|e| &e.values).collect_vec(),
);
channel.mix_digest(tree.root());

CommitmentTreeProver {
polynomials,
evaluations,
commitment: tree,
column_layout,
}
}

/// Decommits the merkle tree on the given query positions.
fn decommit(
&self,
queries: ColumnVec<Vec<usize>>,
) -> (
ColumnVec<Vec<BaseField>>,
MixedDecommitment<BaseField, Blake2sHasher>,
) {
let values = zip(&self.evaluations, &queries)
.map(|(column, column_queries)| column_queries.iter().map(|q| column[*q]).collect())
.collect();

// Rebuild the merkle input for now.
// TODO(Ohad): change after tree refactor. Consider removing the input struct and have the
// decommitment take queries and columns only.
let eval_vec = self
queries: Vec<usize>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<Blake2sHasher>) {
// TODO(spapini): Queries should be the queries to the largest layer.
// When we have more than one component, we should extract the values correctly
let values = self
.evaluations
.iter()
.map(|eval| &eval.values[..])
.collect_vec();
let input = self.column_layout.build_input(&eval_vec);
let decommitment = self.commitment.decommit(&input, &queries);
(values, decommitment)
}
}

impl Deref for CommitmentTreeProver {
type Target = MixedDegreeMerkleTree<BaseField, Blake2sHasher>;
.map(|c| queries.iter().map(|p| c[*p]).collect())
.collect();

fn deref(&self) -> &Self::Target {
&self.commitment
let decommitment = self.commitment.decommit(queries);
(values, decommitment)
}
}
49 changes: 33 additions & 16 deletions src/core/commitment_scheme/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::Reverse;
use std::iter::zip;

use itertools::Itertools;
Expand All @@ -17,7 +18,8 @@ use super::super::queries::SparseSubCircleDomain;
use super::utils::TreeVec;
use super::CommitmentSchemeProof;
use crate::commitment_scheme::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::commitment_scheme::mixed_degree_decommitment::MixedDecommitment;
use crate::commitment_scheme::prover::MerkleDecommitment;
use crate::commitment_scheme::verifier::MerkleTreeVerifier;
use crate::core::channel::Channel;
use crate::core::prover::VerificationError;
use crate::core::ColumnVec;
Expand Down Expand Up @@ -83,15 +85,20 @@ impl CommitmentSchemeVerifier {
let merkle_verification_result = self
.trees
.as_ref()
.zip(&proof.decommitments)
.map(|(tree, decommitment)| {
.zip(proof.decommitments)
.zip(proof.queried_values.clone())
.map(|((tree, decommitment), queried_values)| {
// TODO(spapini): Also verify proved_values here.
let queries = tree
.log_sizes
.iter()
.map(|log_size| fri_query_domains[&(log_size + LOG_BLOWUP_FACTOR)].flatten())
.collect_vec();
tree.verify(decommitment, &queries)
// Assuming columns are of equal lengths, replicate queries for all columns.
// TODO(AlonH): remove this assumption.
tree.verify(
decommitment,
queried_values,
// Queries to the largest size.
fri_query_domains[&(tree.log_sizes[0] + LOG_BLOWUP_FACTOR)]
.flatten()
.clone(),
)
})
.iter()
.all(|x| *x);
Expand Down Expand Up @@ -180,13 +187,23 @@ impl CommitmentTreeVerifier {

pub fn verify(
&self,
decommitment: &MixedDecommitment<BaseField, Blake2sHasher>,
queries: &[Vec<usize>],
decommitment: MerkleDecommitment<Blake2sHasher>,
values: Vec<Vec<BaseField>>,
queries: Vec<usize>,
) -> bool {
decommitment.verify(
self.commitment,
queries,
decommitment.queried_values.iter().copied(),
)
let values = self
.log_sizes
.iter()
.map(|log_size| *log_size + LOG_BLOWUP_FACTOR)
.zip(values)
.sorted_by_key(|(log_size, _)| Reverse(*log_size))
.collect_vec();

// TODO(spapini): Propagate error.
MerkleTreeVerifier {
root: self.commitment,
}
.verify(queries, values, decommitment)
.is_ok()
}
}

0 comments on commit 67cf290

Please sign in to comment.