Skip to content

Commit

Permalink
wip: improve prover impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ralexstokes committed Mar 30, 2024
1 parent 0356ff8 commit a778e10
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 103 deletions.
12 changes: 9 additions & 3 deletions ssz-rs/src/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
de::{Deserialize, DeserializeError},
lib::*,
merkleization::{
proofs::{prove_primitive, Prove, Prover},
proofs::{NoChilden, Prove},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
},
ser::{Serialize, SerializeError},
Expand Down Expand Up @@ -63,8 +63,14 @@ impl GeneralizedIndexable for bool {
}

impl Prove for bool {
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
type Child = NoChilden;

fn chunks(&mut self) -> Result<Vec<u8>, MerkleizationError> {
let mut node = Node::default();
if *self {
node.as_mut()[0] = 1;
}
Ok(node.to_vec())
}
}

Expand Down
5 changes: 5 additions & 0 deletions ssz-rs/src/merkleization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub enum MerkleizationError {
InvalidPath(Vec<PathElement>),
InvalidDepth,
InvalidIndex,
NoChildren,
}

impl From<SerializeError> for MerkleizationError {
Expand All @@ -53,6 +54,10 @@ impl Display for MerkleizationError {
Self::InvalidPath(path) => write!(f, "invalid path {path:?}"),
Self::InvalidDepth => write!(f, "error computing depth for proof"),
Self::InvalidIndex => write!(f, "error computing index for proof"),
Self::NoChildren => write!(
f,
"requested to compute proof for a child which does not exist for this type"
),
}
}
}
Expand Down
143 changes: 115 additions & 28 deletions ssz-rs/src/merkleization/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
lib::*,
merkleization::{
default_generalized_index, generalized_index::log_2, GeneralizedIndex,
compute_merkle_tree, default_generalized_index, generalized_index::log_2, GeneralizedIndex,
GeneralizedIndexable, HashTreeRoot, MerkleizationError as Error, Node, Path,
},
};
Expand All @@ -23,6 +23,21 @@ pub fn get_subtree_index(i: GeneralizedIndex) -> Result<usize, Error> {
Ok(get_index(i, depth))
}

// Identify the generalized index that is the largest parent of `i` that fits in a perfect binary
// tree with `leaf_count` leaves. Return this index along with its depth in the tree
// and its index in the leaf layer.
pub fn compute_local_merkle_coordinates(
mut i: GeneralizedIndex,
leaf_count: usize,
) -> Result<(u32, usize, GeneralizedIndex), Error> {
let node_count = 2 * leaf_count - 1;
while i > node_count {
i /= 2;
}
let depth = get_depth(i)?;
Ok((depth, get_index(i, depth), i))
}

#[derive(Debug)]
pub struct Prover {
pub(crate) hasher: Sha256,
Expand All @@ -45,9 +60,49 @@ impl Prover {
self.witness = witness;
}

pub fn compute_depth_and_index(&self, i: GeneralizedIndex) -> Result<(u32, usize), Error> {
let depth = get_depth(i)?;
Ok((depth, get_index(i, depth)))
/// Derive a Merkle proof relative to `data` given the parameters in `self`.
pub fn compute_proof<T: Prove>(&mut self, data: &mut T) -> Result<(), Error> {
let chunk_count = T::chunk_count();
let leaf_count = chunk_count.next_power_of_two();
let parent_index = self.proof.index;
let (local_depth, local_index, local_generalized_index) =
compute_local_merkle_coordinates(parent_index, leaf_count)?;

let mut is_leaf_local = false;
if local_generalized_index < parent_index {
// NOTE: need to recurse to children to find ultimate leaf
let child_index = if parent_index % 2 == 0 {
parent_index / local_generalized_index
} else {
parent_index / local_generalized_index + 1
};
self.proof.index = child_index;
let child = data.child(local_index)?;
self.compute_proof(child)?;
self.proof.index = parent_index;
} else {
// NOTE: leaf is within the current object, set a flag to grab from merkle tree later
is_leaf_local = true;
}
let chunks = data.chunks()?;
let tree = compute_merkle_tree(&mut self.hasher, &chunks, leaf_count)?;

if is_leaf_local {
let leaf = &tree[parent_index];
self.set_leaf(leaf.try_into().expect("is correct size"));
}

let mut target = local_generalized_index;
for _ in 0..local_depth {
let sibling = if target % 2 != 0 { &tree[target - 1] } else { &tree[target + 1] };
self.extend_branch(sibling.try_into().expect("is correct size"));
target /= 2;
}

let root = &tree[1];
self.set_witness(root.try_into().expect("is correct size"));

Ok(())
}
}

Expand All @@ -68,20 +123,34 @@ impl From<GeneralizedIndex> for Prover {
}

/// Types that can produce Merkle proofs against themselves given a `GeneralizedIndex`.
pub trait Prove {
/// Provide a Merkle proof of the node in this type's merkle tree corresponding to the `index`.
fn prove(&mut self, prover: &mut Prover) -> Result<(), Error>;
pub trait Prove: GeneralizedIndexable {
type Child: Prove;

fn chunks(&mut self) -> Result<Vec<u8>, Error>;

fn child(&mut self, _index: usize) -> Result<&mut Self::Child, Error> {
Err(Error::NoChildren)
}
}

pub struct NoChilden;

impl GeneralizedIndexable for NoChilden {}

impl Prove for NoChilden {
type Child = bool;

fn chunks(&mut self) -> Result<Vec<u8>, Error> {
Err(Error::NoChildren)
}
}

/// Produce a Merkle proof (and corresponding witness) for the type `T` at the given `path` relative
/// to `T`.
pub fn prove<T: GeneralizedIndexable + Prove>(
data: &mut T,
path: Path,
) -> Result<ProofAndWitness, Error> {
pub fn prove<T: Prove>(data: &mut T, path: Path) -> Result<ProofAndWitness, Error> {
let index = T::generalized_index(path)?;
let mut prover = index.into();
data.prove(&mut prover)?;
let mut prover = Prover::from(index);
prover.compute_proof(data)?;
Ok(prover.into())
}

Expand Down Expand Up @@ -167,7 +236,7 @@ pub fn is_valid_merkle_branch<T: AsRef<[u8]>>(

#[cfg(test)]
mod tests {
use crate::U256;
use crate::{PathElement, SimpleSerialize, U256};

use super::*;

Expand Down Expand Up @@ -222,33 +291,51 @@ mod tests {
}

#[test]
fn test_proving_primitives() {
fn test_proving_primitives_fails_with_bad_path() {
let mut data = 8u8;
let (proof, witness) = prove(&mut data, &[]).unwrap();
let result = prove(&mut data, &[PathElement::Length]);
assert!(result.is_err());

let mut data = true;
let result = prove(&mut data, &[234.into()]);
assert!(result.is_err());
}

fn compute_and_verify_proof_for_path<T: SimpleSerialize + Prove>(data: &mut T, path: Path) {
let (proof, witness) = prove(data, path).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
}

#[test]
fn test_prove_primitives() {
let mut data = 8u8;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 0u8;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 234238u64;
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 0u128;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = u128::MAX;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = U256::from_str_radix(
"f8c2ed25e9c31399d4149dcaa48c51f394043a6a1297e65780a5979e3d7bb77c",
16,
)
.unwrap();
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = true;
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok())
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = false;
compute_and_verify_proof_for_path(&mut data, &[]);
}
}
25 changes: 15 additions & 10 deletions ssz-rs/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::{
lib::*,
merkleization::{
pack_bytes,
proofs::{prove_primitive, Prove, Prover},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
proofs::{NoChilden, Prove},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, BYTES_PER_CHUNK,
},
ser::{Serialize, SerializeError},
Serializable, SimpleSerialize, BITS_PER_BYTE,
Expand Down Expand Up @@ -58,9 +58,7 @@ macro_rules! define_uint {

impl HashTreeRoot for $uint {
fn hash_tree_root(&mut self) -> Result<Node, MerkleizationError> {
let mut root = vec![];
let _ = self.serialize(&mut root)?;
pack_bytes(&mut root);
let root = self.chunks()?;
Ok(root.as_slice().try_into().expect("is valid root"))
}

Expand All @@ -76,8 +74,13 @@ macro_rules! define_uint {
}

impl Prove for $uint {
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
type Child = NoChilden;

fn chunks(&mut self) -> Result<Vec<u8>, MerkleizationError> {
let mut root = Vec::with_capacity(BYTES_PER_CHUNK);
let _ = self.serialize(&mut root)?;
pack_bytes(&mut root);
Ok(root)
}
}

Expand Down Expand Up @@ -138,7 +141,7 @@ impl Deserialize for U256 {

impl HashTreeRoot for U256 {
fn hash_tree_root(&mut self) -> Result<Node, MerkleizationError> {
Ok(Node::try_from(self.as_le_bytes().as_ref()).expect("is right size"))
Ok(Node::try_from(self.chunks().unwrap().as_ref()).expect("is right size"))
}

fn is_composite_type() -> bool {
Expand All @@ -153,8 +156,10 @@ impl GeneralizedIndexable for U256 {
}

impl Prove for U256 {
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
type Child = NoChilden;

fn chunks(&mut self) -> Result<Vec<u8>, MerkleizationError> {
Ok(self.as_le_bytes().to_vec())
}
}

Expand Down
Loading

0 comments on commit a778e10

Please sign in to comment.