diff --git a/ssz-rs/src/boolean.rs b/ssz-rs/src/boolean.rs index a4a088ec..eabaa669 100644 --- a/ssz-rs/src/boolean.rs +++ b/ssz-rs/src/boolean.rs @@ -2,8 +2,8 @@ use crate::{ de::{Deserialize, DeserializeError}, lib::*, merkleization::{ - proofs::{prove_primitive, ProofAndWitness, Prove}, - GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, + proofs::{prove_primitive, ProofAndWitness, Prove, Prover}, + GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, }, ser::{Serialize, SerializeError}, Serializable, SimpleSerialize, @@ -63,8 +63,8 @@ impl GeneralizedIndexable for bool { } impl Prove for bool { - fn prove(&mut self, index: GeneralizedIndex) -> Result { - prove_primitive(self, index) + fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { + prove_primitive(self, prover) } } diff --git a/ssz-rs/src/merkleization/merkleize.rs b/ssz-rs/src/merkleization/merkleize.rs index bb34e4ce..5f82bfed 100644 --- a/ssz-rs/src/merkleization/merkleize.rs +++ b/ssz-rs/src/merkleization/merkleize.rs @@ -3,7 +3,9 @@ use crate::{ lib::*, merkleization::{MerkleizationError as Error, Node, BYTES_PER_CHUNK}, ser::Serialize, + GeneralizedIndex, }; +use alloy_primitives::hex; use sha2::{Digest, Sha256}; /// Types that can provide the root of their corresponding Merkle tree following the SSZ spec. @@ -43,7 +45,7 @@ where Ok(buffer) } -fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) { +pub fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) { hasher.update(a); hasher.update(b); out.copy_from_slice(&hasher.finalize_reset()); @@ -226,11 +228,76 @@ pub(crate) fn elements_to_chunks<'a, T: HashTreeRoot + 'a>( Ok(chunks) } +pub struct Tree(Vec); + +impl Index for Tree { + type Output = [u8]; + + fn index(&self, index: GeneralizedIndex) -> &Self::Output { + let start = (index - 1) * BYTES_PER_CHUNK; + let end = index * BYTES_PER_CHUNK; + &self.0[start..end] + } +} + +impl std::fmt::Debug for Tree { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for chunk in self.0.chunks(BYTES_PER_CHUNK) { + let node = hex::encode(chunk); + f.write_str(&node)?; + f.write_str(",\n")?; + } + Ok(()) + } +} + +// Return the full Merkle tree of the `chunks`. +// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0` +// Invariant: `leaf_count.next_power_of_two() == leaf_count` +// NOTE: naive implementation, can make much more efficient +pub fn compute_merkle_tree( + hasher: &mut Sha256, + chunks: &[u8], + leaf_count: usize, +) -> Result { + debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0); + debug_assert!(leaf_count.next_power_of_two() == leaf_count); + + // SAFETY: checked subtraction is unnecessary, + // as leaf_count != 0 (0.next_power_of_two() == 1); qed + let node_count = 2 * leaf_count - 1; + // SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed + let interior_count = node_count - leaf_count; + let leaf_start = interior_count * BYTES_PER_CHUNK; + + let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK]; + buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks); + + for i in (1..node_count).rev().step_by(2) { + // SAFETY: checked subtraction is unnecessary, as i >= 1; qed + let parent_index = (i - 1) / 2; + let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK]; + // SAFETY: checked subtraction is unnecessary: + // focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK + // = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK + // = ((i + 3) / 2) * BYTES_PER_CHUNK + // and + // i >= 1 + // so focus.len() >= 2 * BYTES_PER_CHUNK; qed + let children_index = focus.len() - 2 * BYTES_PER_CHUNK; + // NOTE: children.len() == 2 * BYTES_PER_CHUNK + let (parent, children) = focus.split_at_mut(children_index); + let (left, right) = children.split_at(BYTES_PER_CHUNK); + hash_nodes(hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); + } + Ok(Tree(buffer)) +} + #[cfg(test)] mod tests { use super::*; use crate as ssz_rs; - use crate::prelude::*; + use crate::{merkleization::default_generalized_index, prelude::*}; macro_rules! hex { ($input:expr) => { @@ -238,6 +305,14 @@ mod tests { }; } + // Return the root of the Merklization of a binary tree formed from `chunks`. + fn merkleize_chunks(chunks: &[u8], leaf_count: usize) -> Result { + let mut hasher = Sha256::new(); + let tree = compute_merkle_tree(&mut hasher, chunks, leaf_count)?; + let root_index = default_generalized_index(); + Ok(tree[root_index].try_into().expect("can produce a single root chunk")) + } + #[test] fn test_packing_basic_types_simple() { let b = true; @@ -286,45 +361,6 @@ mod tests { assert_eq!(result, expected); } - // Return the root of the Merklization of a binary tree formed from `chunks`. - // Invariant: `chunks.len() % BYTES_PER_CHUNK == 0` - // Invariant: `leaf_count.next_power_of_two() == leaf_count` - // NOTE: naive implementation, can make much more efficient - fn merkleize_chunks(chunks: &[u8], leaf_count: usize) -> Result { - debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0); - debug_assert!(leaf_count.next_power_of_two() == leaf_count); - - // SAFETY: checked subtraction is unnecessary, - // as leaf_count != 0 (0.next_power_of_two() == 1); qed - let node_count = 2 * leaf_count - 1; - // SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed - let interior_count = node_count - leaf_count; - let leaf_start = interior_count * BYTES_PER_CHUNK; - - let mut hasher = Sha256::new(); - let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK]; - buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks); - - for i in (1..node_count).rev().step_by(2) { - // SAFETY: checked subtraction is unnecessary, as i >= 1; qed - let parent_index = (i - 1) / 2; - let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK]; - // SAFETY: checked subtraction is unnecessary: - // focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK - // = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK - // = ((i + 3) / 2) * BYTES_PER_CHUNK - // and - // i >= 1 - // so focus.len() >= 2 * BYTES_PER_CHUNK; qed - let children_index = focus.len() - 2 * BYTES_PER_CHUNK; - // NOTE: children.len() == 2 * BYTES_PER_CHUNK - let (parent, children) = focus.split_at_mut(children_index); - let (left, right) = children.split_at(BYTES_PER_CHUNK); - hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); - } - Ok(buffer[..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk")) - } - #[test] fn test_naive_merkleize_chunks() { let chunks = vec![0u8; 2 * BYTES_PER_CHUNK]; diff --git a/ssz-rs/src/merkleization/mod.rs b/ssz-rs/src/merkleization/mod.rs index 3b358056..87686d0c 100644 --- a/ssz-rs/src/merkleization/mod.rs +++ b/ssz-rs/src/merkleization/mod.rs @@ -30,6 +30,8 @@ pub enum MerkleizationError { InvalidPathElement(PathElement), /// Signals an invalid path when walking a `GeneralizedIndexable` type InvalidPath(Vec), + InvalidDepth, + InvalidIndex, } impl From for MerkleizationError { @@ -49,6 +51,8 @@ impl Display for MerkleizationError { Self::InvalidGeneralizedIndex => write!(f, "invalid generalized index"), Self::InvalidPathElement(element) => write!(f, "invalid path element {element:?}"), Self::InvalidPath(path) => write!(f, "invalid path {path:?}"), + Self::InvalidDepth => write!(f, "error computing depth for proof"), + Self::InvalidIndex => write!(f, "error computing index for proof"), } } } diff --git a/ssz-rs/src/merkleization/proofs.rs b/ssz-rs/src/merkleization/proofs.rs index c92dcfd5..47565de4 100644 --- a/ssz-rs/src/merkleization/proofs.rs +++ b/ssz-rs/src/merkleization/proofs.rs @@ -10,15 +10,67 @@ use sha2::{Digest, Sha256}; pub type ProofAndWitness = (Proof, Node); +fn get_depth(i: GeneralizedIndex) -> Result { + log_2(i).ok_or(Error::InvalidGeneralizedIndex) +} + +fn get_index(i: GeneralizedIndex, depth: u32) -> usize { + i % 2usize.pow(depth) +} + pub fn get_subtree_index(i: GeneralizedIndex) -> Result { - let i_log2 = log_2(i).ok_or(Error::InvalidGeneralizedIndex)?; - Ok(i % 2usize.pow(i_log2)) + let depth = get_depth(i)?; + Ok(get_index(i, depth)) +} + +#[derive(Debug)] +pub struct Prover { + pub hasher: Sha256, + pub proof: Proof, + pub witness: Node, +} + +impl Prover { + pub fn set_leaf(&mut self, leaf: Node) { + self.proof.leaf = leaf; + } + + // Adds a node to the Merkle proof's branch. + // Assumes nodes are provided going from the bottom of the tree to the top. + pub fn extend_branch(&mut self, node: Node) { + self.proof.branch.push(node) + } + + pub fn set_witness(&mut self, witness: Node) { + self.witness = witness; + } + + pub fn compute_depth_and_index(&self, i: GeneralizedIndex) -> Result<(u32, usize), Error> { + let depth = get_depth(i)?; + Ok((depth, get_index(i, depth))) + } +} + +impl From for ProofAndWitness { + fn from(value: Prover) -> Self { + (value.proof, value.witness) + } +} + +impl From for Prover { + fn from(index: GeneralizedIndex) -> Self { + Self { + hasher: Sha256::new(), + proof: Proof { leaf: Default::default(), branch: vec![], index }, + witness: Default::default(), + } + } } /// Types that can produce Merkle proofs against themselves given a `GeneralizedIndex`. pub trait Prove { /// Provide a Merkle proof of the node in this type's merkle tree corresponding to the `index`. - fn prove(&mut self, index: GeneralizedIndex) -> Result; + fn prove(&mut self, prover: &mut Prover) -> Result<(), Error>; } /// Produce a Merkle proof (and corresponding witness) for the type `T` at the given `path` relative @@ -28,7 +80,9 @@ pub fn prove( path: Path, ) -> Result { let index = T::generalized_index(path)?; - data.prove(index) + let mut prover = index.into(); + data.prove(&mut prover)?; + Ok(prover.into()) } /// Contains data necessary to verify `leaf` was included under some witness "root" node @@ -51,15 +105,17 @@ impl Proof { pub fn prove_primitive( data: &mut T, - index: GeneralizedIndex, -) -> Result { + prover: &mut Prover, +) -> Result<(), Error> { + let index = prover.proof.index; if index != default_generalized_index() { return Err(Error::InvalidGeneralizedIndex) } let root = data.hash_tree_root()?; - let proof = Proof { leaf: root, branch: vec![], index }; - Ok((proof, root)) + prover.set_leaf(root); + prover.set_witness(root); + Ok(()) } pub fn is_valid_merkle_branch_for_generalized_index>( diff --git a/ssz-rs/src/uint.rs b/ssz-rs/src/uint.rs index d74b1903..d7b9db40 100644 --- a/ssz-rs/src/uint.rs +++ b/ssz-rs/src/uint.rs @@ -3,8 +3,8 @@ use crate::{ lib::*, merkleization::{ pack_bytes, - proofs::{prove_primitive, ProofAndWitness, Prove}, - GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, + proofs::{prove_primitive, ProofAndWitness, Prove, Prover}, + GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, }, ser::{Serialize, SerializeError}, Serializable, SimpleSerialize, BITS_PER_BYTE, @@ -76,11 +76,8 @@ macro_rules! define_uint { } impl Prove for $uint { - fn prove( - &mut self, - index: GeneralizedIndex, - ) -> Result { - prove_primitive(self, index) + fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { + prove_primitive(self, prover) } } @@ -156,8 +153,8 @@ impl GeneralizedIndexable for U256 { } impl Prove for U256 { - fn prove(&mut self, index: GeneralizedIndex) -> Result { - prove_primitive(self, index) + fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { + prove_primitive(self, prover) } } diff --git a/ssz-rs/src/vector.rs b/ssz-rs/src/vector.rs index 39015cc1..4c963283 100644 --- a/ssz-rs/src/vector.rs +++ b/ssz-rs/src/vector.rs @@ -3,10 +3,10 @@ use crate::{ error::{Error, InstanceError, TypeError}, lib::*, merkleization::{ - elements_to_chunks, get_power_of_two_ceil, merkleize, pack, - proofs::{ProofAndWitness, Prove}, + compute_merkle_tree, elements_to_chunks, get_power_of_two_ceil, merkleize, pack, + proofs::{Prove, Prover}, GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, Path, - PathElement, + PathElement, BYTES_PER_CHUNK, }, ser::{Serialize, SerializeError, Serializer}, Serializable, SimpleSerialize, @@ -226,16 +226,19 @@ impl Vector where T: SimpleSerialize, { - fn compute_hash_tree_root(&mut self) -> Result { + fn to_chunks(&mut self) -> Result, MerkleizationError> { if T::is_composite_type() { let count = self.len(); - let chunks = elements_to_chunks(self.data.iter_mut().enumerate(), count)?; - merkleize(&chunks, None) + elements_to_chunks(self.data.iter_mut().enumerate(), count) } else { - let chunks = pack(&self.data)?; - merkleize(&chunks, None) + pack(&self.data) } } + + fn compute_hash_tree_root(&mut self) -> Result { + let chunks = self.to_chunks()?; + merkleize(&chunks, None) + } } impl HashTreeRoot for Vector @@ -282,8 +285,40 @@ impl Prove for Vector where T: SimpleSerialize + Prove, { - fn prove(&mut self, index: GeneralizedIndex) -> Result { - todo!() + fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { + let (local_depth, local_index) = prover.compute_depth_and_index(prover.proof.index)?; + if local_index >= N { + return Err(MerkleizationError::InvalidIndex) + } + + let parent_index = prover.proof.index; + let chunk_count = Self::chunk_count(); + let leaf_count = chunk_count.next_power_of_two(); + let child_index = parent_index - leaf_count - local_index + 1; + prover.proof.index = child_index; + self[local_index].prove(prover)?; + prover.proof.index = parent_index; + + let chunks = self.to_chunks()?; + let tree = compute_merkle_tree(&mut prover.hasher, &chunks, leaf_count)?; + + // TODO: remove, but these should match at this point + debug_assert_eq!(&tree[prover.proof.index], prover.witness.as_ref()); + + let mut target = prover.proof.index; + for _ in 0..local_depth { + let sibling = if target % 2 != 0 { &tree[target - 1] } else { &tree[target + 1] }; + prover.extend_branch(sibling.try_into().expect("is correct size")); + target /= 2; + } + + // TODO: remove, but these should match at this point + debug_assert_eq!(&tree[1], self.hash_tree_root().unwrap().as_ref()); + + let root = &tree[1]; + prover.set_witness(root.try_into().expect("is correct size")); + + Ok(()) } } @@ -324,7 +359,7 @@ impl<'de, T: Serializable + serde::Deserialize<'de>, const N: usize> serde::Dese #[cfg(test)] mod tests { use super::*; - use crate::{list::List, serialize}; + use crate::{list::List, serialize, U256}; const COUNT: usize = 32; @@ -465,4 +500,15 @@ mod tests { let path = &[5.into()]; let _ = V::generalized_index(path).unwrap(); } + + #[test] + fn test_prove_vector() { + type V = Vector; + + let mut data = V::default(); + let mut prover = 11.into(); + data.prove(&mut prover).unwrap(); + let (proof, witness) = prover.into(); + proof.verify(witness).unwrap(); + } }