From a778e1064ea717bdf41436796837e2aae25d8dd3 Mon Sep 17 00:00:00 2001 From: Alex Stokes Date: Fri, 29 Mar 2024 18:58:16 -0600 Subject: [PATCH] wip: improve prover impl --- ssz-rs/src/boolean.rs | 12 +- ssz-rs/src/merkleization/mod.rs | 5 + ssz-rs/src/merkleization/proofs.rs | 143 +++++++++++++++---- ssz-rs/src/uint.rs | 25 ++-- ssz-rs/src/vector.rs | 212 ++++++++++++++++++++--------- 5 files changed, 294 insertions(+), 103 deletions(-) diff --git a/ssz-rs/src/boolean.rs b/ssz-rs/src/boolean.rs index 60514a41..bb463fbd 100644 --- a/ssz-rs/src/boolean.rs +++ b/ssz-rs/src/boolean.rs @@ -2,7 +2,7 @@ use crate::{ de::{Deserialize, DeserializeError}, lib::*, merkleization::{ - proofs::{prove_primitive, Prove, Prover}, + proofs::{NoChilden, Prove}, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, }, ser::{Serialize, SerializeError}, @@ -63,8 +63,14 @@ impl GeneralizedIndexable for bool { } impl Prove for bool { - fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { - prove_primitive(self, prover) + type Child = NoChilden; + + fn chunks(&mut self) -> Result, MerkleizationError> { + let mut node = Node::default(); + if *self { + node.as_mut()[0] = 1; + } + Ok(node.to_vec()) } } diff --git a/ssz-rs/src/merkleization/mod.rs b/ssz-rs/src/merkleization/mod.rs index 87686d0c..26239552 100644 --- a/ssz-rs/src/merkleization/mod.rs +++ b/ssz-rs/src/merkleization/mod.rs @@ -32,6 +32,7 @@ pub enum MerkleizationError { InvalidPath(Vec), InvalidDepth, InvalidIndex, + NoChildren, } impl From for MerkleizationError { @@ -53,6 +54,10 @@ impl Display for MerkleizationError { Self::InvalidPath(path) => write!(f, "invalid path {path:?}"), Self::InvalidDepth => write!(f, "error computing depth for proof"), Self::InvalidIndex => write!(f, "error computing index for proof"), + Self::NoChildren => write!( + f, + "requested to compute proof for a child which does not exist for this type" + ), } } } diff --git a/ssz-rs/src/merkleization/proofs.rs b/ssz-rs/src/merkleization/proofs.rs index 1a5ae8dc..683e1f3e 100644 --- a/ssz-rs/src/merkleization/proofs.rs +++ b/ssz-rs/src/merkleization/proofs.rs @@ -2,7 +2,7 @@ use crate::{ lib::*, merkleization::{ - default_generalized_index, generalized_index::log_2, GeneralizedIndex, + compute_merkle_tree, default_generalized_index, generalized_index::log_2, GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError as Error, Node, Path, }, }; @@ -23,6 +23,21 @@ pub fn get_subtree_index(i: GeneralizedIndex) -> Result { Ok(get_index(i, depth)) } +// Identify the generalized index that is the largest parent of `i` that fits in a perfect binary +// tree with `leaf_count` leaves. Return this index along with its depth in the tree +// and its index in the leaf layer. +pub fn compute_local_merkle_coordinates( + mut i: GeneralizedIndex, + leaf_count: usize, +) -> Result<(u32, usize, GeneralizedIndex), Error> { + let node_count = 2 * leaf_count - 1; + while i > node_count { + i /= 2; + } + let depth = get_depth(i)?; + Ok((depth, get_index(i, depth), i)) +} + #[derive(Debug)] pub struct Prover { pub(crate) hasher: Sha256, @@ -45,9 +60,49 @@ impl Prover { self.witness = witness; } - pub fn compute_depth_and_index(&self, i: GeneralizedIndex) -> Result<(u32, usize), Error> { - let depth = get_depth(i)?; - Ok((depth, get_index(i, depth))) + /// Derive a Merkle proof relative to `data` given the parameters in `self`. + pub fn compute_proof(&mut self, data: &mut T) -> Result<(), Error> { + let chunk_count = T::chunk_count(); + let leaf_count = chunk_count.next_power_of_two(); + let parent_index = self.proof.index; + let (local_depth, local_index, local_generalized_index) = + compute_local_merkle_coordinates(parent_index, leaf_count)?; + + let mut is_leaf_local = false; + if local_generalized_index < parent_index { + // NOTE: need to recurse to children to find ultimate leaf + let child_index = if parent_index % 2 == 0 { + parent_index / local_generalized_index + } else { + parent_index / local_generalized_index + 1 + }; + self.proof.index = child_index; + let child = data.child(local_index)?; + self.compute_proof(child)?; + self.proof.index = parent_index; + } else { + // NOTE: leaf is within the current object, set a flag to grab from merkle tree later + is_leaf_local = true; + } + let chunks = data.chunks()?; + let tree = compute_merkle_tree(&mut self.hasher, &chunks, leaf_count)?; + + if is_leaf_local { + let leaf = &tree[parent_index]; + self.set_leaf(leaf.try_into().expect("is correct size")); + } + + let mut target = local_generalized_index; + for _ in 0..local_depth { + let sibling = if target % 2 != 0 { &tree[target - 1] } else { &tree[target + 1] }; + self.extend_branch(sibling.try_into().expect("is correct size")); + target /= 2; + } + + let root = &tree[1]; + self.set_witness(root.try_into().expect("is correct size")); + + Ok(()) } } @@ -68,20 +123,34 @@ impl From for Prover { } /// Types that can produce Merkle proofs against themselves given a `GeneralizedIndex`. -pub trait Prove { - /// Provide a Merkle proof of the node in this type's merkle tree corresponding to the `index`. - fn prove(&mut self, prover: &mut Prover) -> Result<(), Error>; +pub trait Prove: GeneralizedIndexable { + type Child: Prove; + + fn chunks(&mut self) -> Result, Error>; + + fn child(&mut self, _index: usize) -> Result<&mut Self::Child, Error> { + Err(Error::NoChildren) + } +} + +pub struct NoChilden; + +impl GeneralizedIndexable for NoChilden {} + +impl Prove for NoChilden { + type Child = bool; + + fn chunks(&mut self) -> Result, Error> { + Err(Error::NoChildren) + } } /// Produce a Merkle proof (and corresponding witness) for the type `T` at the given `path` relative /// to `T`. -pub fn prove( - data: &mut T, - path: Path, -) -> Result { +pub fn prove(data: &mut T, path: Path) -> Result { let index = T::generalized_index(path)?; - let mut prover = index.into(); - data.prove(&mut prover)?; + let mut prover = Prover::from(index); + prover.compute_proof(data)?; Ok(prover.into()) } @@ -167,7 +236,7 @@ pub fn is_valid_merkle_branch>( #[cfg(test)] mod tests { - use crate::U256; + use crate::{PathElement, SimpleSerialize, U256}; use super::*; @@ -222,33 +291,51 @@ mod tests { } #[test] - fn test_proving_primitives() { + fn test_proving_primitives_fails_with_bad_path() { let mut data = 8u8; - let (proof, witness) = prove(&mut data, &[]).unwrap(); + let result = prove(&mut data, &[PathElement::Length]); + assert!(result.is_err()); + + let mut data = true; + let result = prove(&mut data, &[234.into()]); + assert!(result.is_err()); + } + + fn compute_and_verify_proof_for_path(data: &mut T, path: Path) { + let (proof, witness) = prove(data, path).unwrap(); assert_eq!(witness, data.hash_tree_root().unwrap()); let result = proof.verify(witness); assert!(result.is_ok()); + } + + #[test] + fn test_prove_primitives() { + let mut data = 8u8; + compute_and_verify_proof_for_path(&mut data, &[]); + + let mut data = 0u8; + compute_and_verify_proof_for_path(&mut data, &[]); let mut data = 234238u64; - let (proof, witness) = prove(&mut data, &[]).unwrap(); - assert_eq!(witness, data.hash_tree_root().unwrap()); - let result = proof.verify(witness); - assert!(result.is_ok()); + compute_and_verify_proof_for_path(&mut data, &[]); + + let mut data = 0u128; + compute_and_verify_proof_for_path(&mut data, &[]); + + let mut data = u128::MAX; + compute_and_verify_proof_for_path(&mut data, &[]); let mut data = U256::from_str_radix( "f8c2ed25e9c31399d4149dcaa48c51f394043a6a1297e65780a5979e3d7bb77c", 16, ) .unwrap(); - let (proof, witness) = prove(&mut data, &[]).unwrap(); - assert_eq!(witness, data.hash_tree_root().unwrap()); - let result = proof.verify(witness); - assert!(result.is_ok()); + compute_and_verify_proof_for_path(&mut data, &[]); let mut data = true; - let (proof, witness) = prove(&mut data, &[]).unwrap(); - assert_eq!(witness, data.hash_tree_root().unwrap()); - let result = proof.verify(witness); - assert!(result.is_ok()) + compute_and_verify_proof_for_path(&mut data, &[]); + + let mut data = false; + compute_and_verify_proof_for_path(&mut data, &[]); } } diff --git a/ssz-rs/src/uint.rs b/ssz-rs/src/uint.rs index 120d036b..db16ac0f 100644 --- a/ssz-rs/src/uint.rs +++ b/ssz-rs/src/uint.rs @@ -3,8 +3,8 @@ use crate::{ lib::*, merkleization::{ pack_bytes, - proofs::{prove_primitive, Prove, Prover}, - GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, + proofs::{NoChilden, Prove}, + GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, BYTES_PER_CHUNK, }, ser::{Serialize, SerializeError}, Serializable, SimpleSerialize, BITS_PER_BYTE, @@ -58,9 +58,7 @@ macro_rules! define_uint { impl HashTreeRoot for $uint { fn hash_tree_root(&mut self) -> Result { - let mut root = vec![]; - let _ = self.serialize(&mut root)?; - pack_bytes(&mut root); + let root = self.chunks()?; Ok(root.as_slice().try_into().expect("is valid root")) } @@ -76,8 +74,13 @@ macro_rules! define_uint { } impl Prove for $uint { - fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { - prove_primitive(self, prover) + type Child = NoChilden; + + fn chunks(&mut self) -> Result, MerkleizationError> { + let mut root = Vec::with_capacity(BYTES_PER_CHUNK); + let _ = self.serialize(&mut root)?; + pack_bytes(&mut root); + Ok(root) } } @@ -138,7 +141,7 @@ impl Deserialize for U256 { impl HashTreeRoot for U256 { fn hash_tree_root(&mut self) -> Result { - Ok(Node::try_from(self.as_le_bytes().as_ref()).expect("is right size")) + Ok(Node::try_from(self.chunks().unwrap().as_ref()).expect("is right size")) } fn is_composite_type() -> bool { @@ -153,8 +156,10 @@ impl GeneralizedIndexable for U256 { } impl Prove for U256 { - fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { - prove_primitive(self, prover) + type Child = NoChilden; + + fn chunks(&mut self) -> Result, MerkleizationError> { + Ok(self.as_le_bytes().to_vec()) } } diff --git a/ssz-rs/src/vector.rs b/ssz-rs/src/vector.rs index bece71bf..a5f6f864 100644 --- a/ssz-rs/src/vector.rs +++ b/ssz-rs/src/vector.rs @@ -3,8 +3,7 @@ use crate::{ error::{Error, InstanceError, TypeError}, lib::*, merkleization::{ - compute_merkle_tree, elements_to_chunks, get_power_of_two_ceil, merkleize, pack, - proofs::{Prove, Prover}, + elements_to_chunks, get_power_of_two_ceil, merkleize, pack, proofs::Prove, GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, Path, PathElement, }, @@ -285,47 +284,18 @@ impl Prove for Vector where T: SimpleSerialize + Prove, { - fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> { - let (local_depth, local_index) = prover.compute_depth_and_index(prover.proof.index)?; - if local_index >= N { - return Err(MerkleizationError::InvalidIndex) - } - - let chunk_count = Self::chunk_count(); - let leaf_count = chunk_count.next_power_of_two(); - let mut is_basic_type = false; - if T::is_composite_type() { - let parent_index = prover.proof.index; - let child_index = parent_index - leaf_count - local_index + 1; - prover.proof.index = child_index; - self[local_index].prove(prover)?; - prover.proof.index = parent_index; - } else { - // need to set leaf from merkle tree below... - is_basic_type = true; - } - let chunks = self.to_chunks()?; - let tree = compute_merkle_tree(&mut prover.hasher, &chunks, leaf_count)?; + type Child = T; - if is_basic_type { - let leaf = &tree[prover.proof.index]; - prover.set_leaf(leaf.try_into().expect("is correct size")); - } + fn chunks(&mut self) -> Result, MerkleizationError> { + self.to_chunks() + } - let mut target = prover.proof.index; - for _ in 0..local_depth { - let sibling = if target % 2 != 0 { &tree[target - 1] } else { &tree[target + 1] }; - prover.extend_branch(sibling.try_into().expect("is correct size")); - target /= 2; + fn child(&mut self, index: usize) -> Result<&mut Self::Child, MerkleizationError> { + if index >= N { + Err(MerkleizationError::InvalidIndex) + } else { + Ok(&mut self[index]) } - - // TODO: remove, but these should match at this point - debug_assert_eq!(&tree[1], self.hash_tree_root().unwrap().as_ref()); - - let root = &tree[1]; - prover.set_witness(root.try_into().expect("is correct size")); - - Ok(()) } } @@ -366,7 +336,11 @@ impl<'de, T: Serializable + serde::Deserialize<'de>, const N: usize> serde::Dese #[cfg(test)] mod tests { use super::*; - use crate::{list::List, merkleization::proofs::prove, serialize, U256}; + use crate::{ + list::List, + merkleization::proofs::{prove, Prover}, + serialize, U256, + }; const COUNT: usize = 32; @@ -509,7 +483,69 @@ mod tests { } #[test] - fn test_prove_vector_over_chunk_sized_primitive() { + fn test_generalized_index_for_vector_over_non_aligned_vector() { + type W = Vector; + type V = Vector; + + let path = &[0.into(), 0.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 2); + + let path = &[0.into(), 1.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 2); + + let path = &[1.into(), 0.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 3); + + let path = &[1.into(), 1.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 3); + } + + #[test] + fn test_generalized_index_for_vector_over_aligned_vector() { + type W = Vector; + type V = Vector; + + let path = &[0.into(), 0.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 4); + + let path = &[0.into(), 1.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 5); + + let path = &[1.into(), 0.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 6); + + let path = &[1.into(), 1.into()]; + let index = V::generalized_index(path).unwrap(); + assert_eq!(index, 7); + } + + fn compute_and_verify_proof( + data: &mut T, + path: Path, + expected_index: GeneralizedIndex, + ) { + let (proof, witness) = prove(data, path).unwrap(); + assert!(proof.verify(witness).is_ok()); + + let index = T::generalized_index(path).unwrap(); + assert_eq!(expected_index, index); + let mut prover = Prover::from(expected_index); + prover.compute_proof(data).unwrap(); + let (proof_from_index, witness_from_index) = prover.into(); + assert_eq!(proof, proof_from_index); + assert_eq!(witness, witness_from_index); + assert!(proof.verify(witness).is_ok()); + } + + #[test] + fn test_prove_vector_over_aligned_primitive() { type V = Vector; let mut data = V::try_from(vec![ @@ -524,36 +560,88 @@ mod tests { .unwrap(); let path = &[3.into()]; - let (proof, witness) = prove(&mut data, path).unwrap(); - let expected_index = 11; - let index = V::generalized_index(path).unwrap(); - assert_eq!(expected_index, index); - let mut prover = expected_index.into(); - data.prove(&mut prover).unwrap(); - let (proof_from_index, witness_from_index) = prover.into(); - assert_eq!(proof, proof_from_index); - assert_eq!(witness, witness_from_index); - assert!(proof.verify(witness).is_ok()); + compute_and_verify_proof(&mut data, path, expected_index); } #[test] - fn test_prove_vector_over_nonchunk_sized_primitive() { + fn test_prove_vector_over_non_aligned_primitive() { type V = Vector; let mut data = V::try_from(vec![23, 34, 45, 56, 67, 78, 11]).unwrap(); let path = &[3.into()]; - let (proof, witness) = prove(&mut data, path).unwrap(); + let expected_index = 2; + compute_and_verify_proof(&mut data, path, expected_index); + } + + #[test] + fn test_prove_vector_over_vector_of_non_aligned_primitives() { + type W = Vector; + type V = Vector; + let inner = W::try_from(vec![true, true]).unwrap(); + let mut data = V::try_from(vec![inner.clone(), inner]).unwrap(); + + // prove into non-leaf + let path = &[0.into()]; let expected_index = 2; - let index = V::generalized_index(path).unwrap(); - assert_eq!(expected_index, index); - let mut prover = expected_index.into(); - data.prove(&mut prover).unwrap(); - let (proof_from_index, witness_from_index) = prover.into(); - assert_eq!(proof, proof_from_index); - assert_eq!(witness, witness_from_index); - assert!(proof_from_index.verify(witness_from_index).is_ok()); + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into()]; + let expected_index = 3; + compute_and_verify_proof(&mut data, path, expected_index); + + // prove into leaf + let path = &[0.into(), 0.into()]; + let expected_index = 2; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[0.into(), 1.into()]; + let expected_index = 2; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into(), 0.into()]; + let expected_index = 3; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into(), 1.into()]; + let expected_index = 3; + compute_and_verify_proof(&mut data, path, expected_index); + } + + #[test] + fn test_prove_vector_over_vector_of_aligned_primitives() { + type W = Vector; + type V = Vector; + + let inner = W::try_from(vec![U256::from(1), U256::from(2)]).unwrap(); + let mut data = V::try_from(vec![inner.clone(), inner]).unwrap(); + + // prove into non-leaf + let path = &[0.into()]; + let expected_index = 2; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into()]; + let expected_index = 3; + compute_and_verify_proof(&mut data, path, expected_index); + + // prove into leaf + let path = &[0.into(), 0.into()]; + let expected_index = 4; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[0.into(), 1.into()]; + let expected_index = 5; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into(), 0.into()]; + let expected_index = 6; + compute_and_verify_proof(&mut data, path, expected_index); + + let path = &[1.into(), 1.into()]; + let expected_index = 7; + compute_and_verify_proof(&mut data, path, expected_index); } }