Skip to content

Commit

Permalink
implement merkle proving for Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
ralexstokes committed Mar 30, 2024
1 parent aaf93a2 commit 96c21e2
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 77 deletions.
2 changes: 1 addition & 1 deletion ssz-rs/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ where

impl<T, const N: usize> GeneralizedIndexable for [T; N]
where
T: GeneralizedIndexable,
T: SimpleSerialize,
{
fn chunk_count() -> usize {
(N * T::item_length() + 31) / 32
Expand Down
8 changes: 4 additions & 4 deletions ssz-rs/src/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
de::{Deserialize, DeserializeError},
lib::*,
merkleization::{
proofs::{prove_primitive, ProofAndWitness, Prove},
GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
proofs::{prove_primitive, Prove, Prover},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
},
ser::{Serialize, SerializeError},
Serializable, SimpleSerialize,
Expand Down Expand Up @@ -63,8 +63,8 @@ impl GeneralizedIndexable for bool {
}

impl Prove for bool {
fn prove(&mut self, index: GeneralizedIndex) -> Result<ProofAndWitness, MerkleizationError> {
prove_primitive(self, index)
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
}
}

Expand Down
2 changes: 1 addition & 1 deletion ssz-rs/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ where

impl<T, const N: usize> GeneralizedIndexable for List<T, N>
where
T: SimpleSerialize + GeneralizedIndexable,
T: SimpleSerialize,
{
fn chunk_count() -> usize {
(N * T::item_length() + 31) / 32
Expand Down
113 changes: 72 additions & 41 deletions ssz-rs/src/merkleization/merkleize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::{
lib::*,
merkleization::{MerkleizationError as Error, Node, BYTES_PER_CHUNK},
ser::Serialize,
GeneralizedIndex,
};
use alloy_primitives::hex;
use sha2::{Digest, Sha256};

/// Types that can provide the root of their corresponding Merkle tree following the SSZ spec.
Expand Down Expand Up @@ -43,7 +45,7 @@ where
Ok(buffer)
}

fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) {
pub fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) {
hasher.update(a);
hasher.update(b);
out.copy_from_slice(&hasher.finalize_reset());
Expand Down Expand Up @@ -226,18 +228,86 @@ pub(crate) fn elements_to_chunks<'a, T: HashTreeRoot + 'a>(
Ok(chunks)
}

pub struct Tree(Vec<u8>);

impl Index<GeneralizedIndex> for Tree {
type Output = [u8];

fn index(&self, index: GeneralizedIndex) -> &Self::Output {
let start = (index - 1) * BYTES_PER_CHUNK;
let end = index * BYTES_PER_CHUNK;
&self.0[start..end]
}
}

impl std::fmt::Debug for Tree {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.chunks(BYTES_PER_CHUNK).map(hex::encode)).finish()
}
}

// Return the full Merkle tree of the `chunks`.
// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0`
// Invariant: `leaf_count.next_power_of_two() == leaf_count`
// NOTE: naive implementation, can make much more efficient
pub fn compute_merkle_tree(
hasher: &mut Sha256,
chunks: &[u8],
leaf_count: usize,
) -> Result<Tree, Error> {
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
debug_assert!(leaf_count.next_power_of_two() == leaf_count);

// SAFETY: checked subtraction is unnecessary,
// as leaf_count != 0 (0.next_power_of_two() == 1); qed
let node_count = 2 * leaf_count - 1;
// SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed
let interior_count = node_count - leaf_count;
let leaf_start = interior_count * BYTES_PER_CHUNK;

let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK];
buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks);

for i in (1..node_count).rev().step_by(2) {
// SAFETY: checked subtraction is unnecessary, as i >= 1; qed
let parent_index = (i - 1) / 2;
let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK
// = ((i + 3) / 2) * BYTES_PER_CHUNK
// and
// i >= 1
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
// NOTE: children.len() == 2 * BYTES_PER_CHUNK
let (parent, children) = focus.split_at_mut(children_index);
let (left, right) = children.split_at(BYTES_PER_CHUNK);
hash_nodes(hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
Ok(Tree(buffer))
}

#[cfg(test)]
mod tests {
use super::*;
use crate as ssz_rs;
use crate::prelude::*;
use crate::{merkleization::default_generalized_index, prelude::*};

macro_rules! hex {
($input:expr) => {
hex::decode($input).unwrap()
};
}

// Return the root of the Merklization of a binary tree formed from `chunks`.
fn merkleize_chunks(chunks: &[u8], leaf_count: usize) -> Result<Node, Error> {
let mut hasher = Sha256::new();
let tree = compute_merkle_tree(&mut hasher, chunks, leaf_count)?;
let root_index = default_generalized_index();
Ok(tree[root_index].try_into().expect("can produce a single root chunk"))
}

#[test]
fn test_packing_basic_types_simple() {
let b = true;
Expand Down Expand Up @@ -286,45 +356,6 @@ mod tests {
assert_eq!(result, expected);
}

// Return the root of the Merklization of a binary tree formed from `chunks`.
// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0`
// Invariant: `leaf_count.next_power_of_two() == leaf_count`
// NOTE: naive implementation, can make much more efficient
fn merkleize_chunks(chunks: &[u8], leaf_count: usize) -> Result<Node, MerkleizationError> {
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
debug_assert!(leaf_count.next_power_of_two() == leaf_count);

// SAFETY: checked subtraction is unnecessary,
// as leaf_count != 0 (0.next_power_of_two() == 1); qed
let node_count = 2 * leaf_count - 1;
// SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed
let interior_count = node_count - leaf_count;
let leaf_start = interior_count * BYTES_PER_CHUNK;

let mut hasher = Sha256::new();
let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK];
buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks);

for i in (1..node_count).rev().step_by(2) {
// SAFETY: checked subtraction is unnecessary, as i >= 1; qed
let parent_index = (i - 1) / 2;
let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK
// = ((i + 3) / 2) * BYTES_PER_CHUNK
// and
// i >= 1
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
// NOTE: children.len() == 2 * BYTES_PER_CHUNK
let (parent, children) = focus.split_at_mut(children_index);
let (left, right) = children.split_at(BYTES_PER_CHUNK);
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
Ok(buffer[..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk"))
}

#[test]
fn test_naive_merkleize_chunks() {
let chunks = vec![0u8; 2 * BYTES_PER_CHUNK];
Expand Down
4 changes: 4 additions & 0 deletions ssz-rs/src/merkleization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub enum MerkleizationError {
InvalidPathElement(PathElement),
/// Signals an invalid path when walking a `GeneralizedIndexable` type
InvalidPath(Vec<PathElement>),
InvalidDepth,
InvalidIndex,
}

impl From<SerializeError> for MerkleizationError {
Expand All @@ -49,6 +51,8 @@ impl Display for MerkleizationError {
Self::InvalidGeneralizedIndex => write!(f, "invalid generalized index"),
Self::InvalidPathElement(element) => write!(f, "invalid path element {element:?}"),
Self::InvalidPath(path) => write!(f, "invalid path {path:?}"),
Self::InvalidDepth => write!(f, "error computing depth for proof"),
Self::InvalidIndex => write!(f, "error computing index for proof"),
}
}
}
Expand Down
74 changes: 65 additions & 9 deletions ssz-rs/src/merkleization/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,67 @@ use sha2::{Digest, Sha256};

pub type ProofAndWitness = (Proof, Node);

fn get_depth(i: GeneralizedIndex) -> Result<u32, Error> {
log_2(i).ok_or(Error::InvalidGeneralizedIndex)
}

fn get_index(i: GeneralizedIndex, depth: u32) -> usize {
i % 2usize.pow(depth)
}

pub fn get_subtree_index(i: GeneralizedIndex) -> Result<usize, Error> {
let i_log2 = log_2(i).ok_or(Error::InvalidGeneralizedIndex)?;
Ok(i % 2usize.pow(i_log2))
let depth = get_depth(i)?;
Ok(get_index(i, depth))
}

#[derive(Debug)]
pub struct Prover {
pub(crate) hasher: Sha256,
pub proof: Proof,
pub witness: Node,
}

impl Prover {
pub fn set_leaf(&mut self, leaf: Node) {
self.proof.leaf = leaf;
}

// Adds a node to the Merkle proof's branch.
// Assumes nodes are provided going from the bottom of the tree to the top.
pub fn extend_branch(&mut self, node: Node) {
self.proof.branch.push(node)
}

pub fn set_witness(&mut self, witness: Node) {
self.witness = witness;
}

pub fn compute_depth_and_index(&self, i: GeneralizedIndex) -> Result<(u32, usize), Error> {
let depth = get_depth(i)?;
Ok((depth, get_index(i, depth)))
}
}

impl From<Prover> for ProofAndWitness {
fn from(value: Prover) -> Self {
(value.proof, value.witness)
}
}

impl From<GeneralizedIndex> for Prover {
fn from(index: GeneralizedIndex) -> Self {
Self {
hasher: Sha256::new(),
proof: Proof { leaf: Default::default(), branch: vec![], index },
witness: Default::default(),
}
}
}

/// Types that can produce Merkle proofs against themselves given a `GeneralizedIndex`.
pub trait Prove {
/// Provide a Merkle proof of the node in this type's merkle tree corresponding to the `index`.
fn prove(&mut self, index: GeneralizedIndex) -> Result<ProofAndWitness, Error>;
fn prove(&mut self, prover: &mut Prover) -> Result<(), Error>;
}

/// Produce a Merkle proof (and corresponding witness) for the type `T` at the given `path` relative
Expand All @@ -28,12 +80,14 @@ pub fn prove<T: GeneralizedIndexable + Prove>(
path: Path,
) -> Result<ProofAndWitness, Error> {
let index = T::generalized_index(path)?;
data.prove(index)
let mut prover = index.into();
data.prove(&mut prover)?;
Ok(prover.into())
}

/// Contains data necessary to verify `leaf` was included under some witness "root" node
/// at the generalized position `index`.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub struct Proof {
pub leaf: Node,
pub branch: Vec<Node>,
Expand All @@ -51,15 +105,17 @@ impl Proof {

pub fn prove_primitive<T: HashTreeRoot + ?Sized>(
data: &mut T,
index: GeneralizedIndex,
) -> Result<ProofAndWitness, Error> {
prover: &mut Prover,
) -> Result<(), Error> {
let index = prover.proof.index;
if index != default_generalized_index() {
return Err(Error::InvalidGeneralizedIndex)
}

let root = data.hash_tree_root()?;
let proof = Proof { leaf: root, branch: vec![], index };
Ok((proof, root))
prover.set_leaf(root);
prover.set_witness(root);
Ok(())
}

pub fn is_valid_merkle_branch_for_generalized_index<T: AsRef<[u8]>>(
Expand Down
15 changes: 6 additions & 9 deletions ssz-rs/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::{
lib::*,
merkleization::{
pack_bytes,
proofs::{prove_primitive, ProofAndWitness, Prove},
GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
proofs::{prove_primitive, Prove, Prover},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
},
ser::{Serialize, SerializeError},
Serializable, SimpleSerialize, BITS_PER_BYTE,
Expand Down Expand Up @@ -76,11 +76,8 @@ macro_rules! define_uint {
}

impl Prove for $uint {
fn prove(
&mut self,
index: GeneralizedIndex,
) -> Result<ProofAndWitness, MerkleizationError> {
prove_primitive(self, index)
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
}
}

Expand Down Expand Up @@ -156,8 +153,8 @@ impl GeneralizedIndexable for U256 {
}

impl Prove for U256 {
fn prove(&mut self, index: GeneralizedIndex) -> Result<ProofAndWitness, MerkleizationError> {
prove_primitive(self, index)
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
}
}

Expand Down
Loading

0 comments on commit 96c21e2

Please sign in to comment.