Skip to content

Commit

Permalink
refactor proving to consolidate implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ralexstokes committed Mar 30, 2024
1 parent a163ef9 commit 1929aa4
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 143 deletions.
14 changes: 10 additions & 4 deletions ssz-rs/src/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
de::{Deserialize, DeserializeError},
lib::*,
merkleization::{
proofs::{prove_primitive, Prove, Prover},
GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
proofs::Prove, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node,
BYTES_PER_CHUNK,
},
ser::{Serialize, SerializeError},
Serializable, SimpleSerialize,
Expand Down Expand Up @@ -63,8 +63,14 @@ impl GeneralizedIndexable for bool {
}

impl Prove for bool {
fn prove(&mut self, prover: &mut Prover) -> Result<(), MerkleizationError> {
prove_primitive(self, prover)
type InnerElement = ();

fn chunks(&mut self) -> Result<Vec<u8>, MerkleizationError> {
let mut vec = vec![0u8; BYTES_PER_CHUNK];
if *self {
vec[0] = 1;
}
Ok(vec)
}
}

Expand Down
1 change: 1 addition & 0 deletions ssz-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ mod exports {
error::{Error as SimpleSerializeError, InstanceError, TypeError},
list::List,
merkleization::{
generalized_index::default_generalized_index,
multiproofs,
proofs::{self, is_valid_merkle_branch},
GeneralizedIndex, GeneralizedIndexable, HashTreeRoot, MerkleizationError, Node, Path,
Expand Down
7 changes: 3 additions & 4 deletions ssz-rs/src/merkleization/merkleize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
ser::Serialize,
GeneralizedIndex,
};
use alloy_primitives::hex;
use sha2::{Digest, Sha256};

/// Types that can provide the root of their corresponding Merkle tree following the SSZ spec.
Expand Down Expand Up @@ -45,7 +44,7 @@ where
Ok(buffer)
}

pub fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) {
fn hash_nodes(hasher: &mut Sha256, a: &[u8], b: &[u8], out: &mut [u8]) {
hasher.update(a);
hasher.update(b);
out.copy_from_slice(&hasher.finalize_reset());
Expand Down Expand Up @@ -240,6 +239,7 @@ impl Index<GeneralizedIndex> for Tree {
}
}

#[cfg(feature = "serde")]
impl std::fmt::Debug for Tree {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.chunks(BYTES_PER_CHUNK).map(hex::encode)).finish()
Expand Down Expand Up @@ -291,8 +291,7 @@ pub fn compute_merkle_tree(
#[cfg(test)]
mod tests {
use super::*;
use crate as ssz_rs;
use crate::{merkleization::default_generalized_index, prelude::*};
use crate::prelude::*;

macro_rules! hex {
($input:expr) => {
Expand Down
22 changes: 16 additions & 6 deletions ssz-rs/src/merkleization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ pub mod proofs;

use crate::{lib::*, ser::SerializeError};
pub use generalized_index::{
default_generalized_index, get_power_of_two_ceil, GeneralizedIndex, GeneralizedIndexable, Path,
PathElement,
get_power_of_two_ceil, GeneralizedIndex, GeneralizedIndexable, Path, PathElement,
};
pub use merkleize::*;
pub use node::*;
Expand All @@ -30,8 +29,13 @@ pub enum MerkleizationError {
InvalidPathElement(PathElement),
/// Signals an invalid path when walking a `GeneralizedIndexable` type
InvalidPath(Vec<PathElement>),
InvalidDepth,
InvalidIndex,
/// Attempt to prove an inner element outside the bounds of what the implementing type
/// supports.
InvalidInnerIndex,
/// Attempt to prove an inner element for a "basic" type that doesn't have one
NoInnerElement,
/// Attempt to turn an instance of a type in Merkle chunks when this is not supported
NotChunkable,
}

impl From<SerializeError> for MerkleizationError {
Expand All @@ -51,8 +55,14 @@ impl Display for MerkleizationError {
Self::InvalidGeneralizedIndex => write!(f, "invalid generalized index"),
Self::InvalidPathElement(element) => write!(f, "invalid path element {element:?}"),
Self::InvalidPath(path) => write!(f, "invalid path {path:?}"),
Self::InvalidDepth => write!(f, "error computing depth for proof"),
Self::InvalidIndex => write!(f, "error computing index for proof"),
Self::InvalidInnerIndex => write!(f, "requested to compute proof for an inner element outside the bounds of what this type supports"),
Self::NoInnerElement => write!(
f,
"requested to compute proof for an inner element which does not exist for this type"
),
Self::NotChunkable => {
write!(f, "requested to compute chunks for a type which does not support this")
}
}
}
}
Expand Down
189 changes: 135 additions & 54 deletions ssz-rs/src/merkleization/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
use crate::{
lib::*,
merkleization::{
default_generalized_index, generalized_index::log_2, GeneralizedIndex,
GeneralizedIndexable, HashTreeRoot, MerkleizationError as Error, Node, Path,
compute_merkle_tree, generalized_index::log_2, GeneralizedIndex, GeneralizedIndexable,
MerkleizationError as Error, Node, Path,
},
};
use sha2::{Digest, Sha256};
Expand All @@ -23,31 +23,84 @@ pub fn get_subtree_index(i: GeneralizedIndex) -> Result<usize, Error> {
Ok(get_index(i, depth))
}

// Identify the generalized index that is the largest parent of `i` that fits in a perfect binary
// tree with `leaf_count` leaves. Return this index along with its depth in the tree
// and its index in the leaf layer.
pub fn compute_local_merkle_coordinates(
mut i: GeneralizedIndex,
leaf_count: usize,
) -> Result<(u32, usize, GeneralizedIndex), Error> {
let node_count = 2 * leaf_count - 1;
while i > node_count {
i /= 2;
}
let depth = get_depth(i)?;
Ok((depth, get_index(i, depth), i))
}

#[derive(Debug)]
pub struct Prover {
pub(crate) hasher: Sha256,
pub proof: Proof,
pub witness: Node,
hasher: Sha256,
proof: Proof,
witness: Node,
}

impl Prover {
pub fn set_leaf(&mut self, leaf: Node) {
self.proof.leaf = leaf;
fn set_leaf(&mut self, leaf: &[u8]) {
self.proof.leaf = leaf.try_into().expect("is correct size");
}

// Adds a node to the Merkle proof's branch.
// Assumes nodes are provided going from the bottom of the tree to the top.
pub fn extend_branch(&mut self, node: Node) {
self.proof.branch.push(node)
fn extend_branch(&mut self, node: &[u8]) {
self.proof.branch.push(node.try_into().expect("is correct size"))
}

pub fn set_witness(&mut self, witness: Node) {
self.witness = witness;
fn set_witness(&mut self, witness: &[u8]) {
self.witness = witness.try_into().expect("is correct size");
}

pub fn compute_depth_and_index(&self, i: GeneralizedIndex) -> Result<(u32, usize), Error> {
let depth = get_depth(i)?;
Ok((depth, get_index(i, depth)))
/// Derive a Merkle proof relative to `data` given the parameters in `self`.
pub fn compute_proof<T: Prove>(&mut self, data: &mut T) -> Result<(), Error> {
let chunk_count = T::chunk_count();
let leaf_count = chunk_count.next_power_of_two();
let parent_index = self.proof.index;
let (local_depth, local_index, local_generalized_index) =
compute_local_merkle_coordinates(parent_index, leaf_count)?;

let mut is_leaf_local = false;
if local_generalized_index < parent_index {
// NOTE: need to recurse to children to find ultimate leaf
let child_index = if parent_index % 2 == 0 {
parent_index / local_generalized_index
} else {
parent_index / local_generalized_index + 1
};
self.proof.index = child_index;
let child = data.inner_element(local_index)?;
self.compute_proof(child)?;
self.proof.index = parent_index;
} else {
// NOTE: leaf is within the current object, set a flag to grab from merkle tree later
is_leaf_local = true;
}
let chunks = data.chunks()?;
let tree = compute_merkle_tree(&mut self.hasher, &chunks, leaf_count)?;

if is_leaf_local {
self.set_leaf(&tree[parent_index]);
}

let mut target = local_generalized_index;
for _ in 0..local_depth {
let sibling = if target % 2 != 0 { &tree[target - 1] } else { &tree[target + 1] };
self.extend_branch(sibling);
target /= 2;
}

self.set_witness(&tree[1]);

Ok(())
}
}

Expand All @@ -67,21 +120,46 @@ impl From<GeneralizedIndex> for Prover {
}
}

/// Types that can produce Merkle proofs against themselves given a `GeneralizedIndex`.
pub trait Prove {
/// Provide a Merkle proof of the node in this type's merkle tree corresponding to the `index`.
fn prove(&mut self, prover: &mut Prover) -> Result<(), Error>;
/// Required functionality to support computing Merkle proofs.
pub trait Prove: GeneralizedIndexable {
type InnerElement: Prove;

/// Compute the "chunks" of this type as required for the SSZ merkle tree computation.
/// Default implementation signals an error. Implementing types should override
/// to provide the correct behavior.
fn chunks(&mut self) -> Result<Vec<u8>, Error> {
Err(Error::NotChunkable)
}

/// Provide a reference to a member element of a composite type.
/// Default implementation signals an error. Implementing types should override
/// to provide the correct behavior.
fn inner_element(&mut self, _index: usize) -> Result<&mut Self::InnerElement, Error> {
Err(Error::NoInnerElement)
}
}

// Implement `GeneralizedIndexable` for `()` for use as a marker type in `Prove`.
impl GeneralizedIndexable for () {
fn compute_generalized_index(
_parent: GeneralizedIndex,
path: Path,
) -> Result<GeneralizedIndex, Error> {
Err(Error::InvalidPath(path.to_vec()))
}
}

// Implement the default `Prove` functionality for use of `()` as a marker type.
impl Prove for () {
type InnerElement = ();
}

/// Produce a Merkle proof (and corresponding witness) for the type `T` at the given `path` relative
/// to `T`.
pub fn prove<T: GeneralizedIndexable + Prove>(
data: &mut T,
path: Path,
) -> Result<ProofAndWitness, Error> {
pub fn prove<T: Prove>(data: &mut T, path: Path) -> Result<ProofAndWitness, Error> {
let index = T::generalized_index(path)?;
let mut prover = index.into();
data.prove(&mut prover)?;
let mut prover = Prover::from(index);
prover.compute_proof(data)?;
Ok(prover.into())
}

Expand All @@ -103,21 +181,6 @@ impl Proof {
}
}

pub fn prove_primitive<T: HashTreeRoot + ?Sized>(
data: &mut T,
prover: &mut Prover,
) -> Result<(), Error> {
let index = prover.proof.index;
if index != default_generalized_index() {
return Err(Error::InvalidGeneralizedIndex)
}

let root = data.hash_tree_root()?;
prover.set_leaf(root);
prover.set_witness(root);
Ok(())
}

pub fn is_valid_merkle_branch_for_generalized_index<T: AsRef<[u8]>>(
leaf: Node,
branch: &[T],
Expand Down Expand Up @@ -167,7 +230,7 @@ pub fn is_valid_merkle_branch<T: AsRef<[u8]>>(

#[cfg(test)]
mod tests {
use crate::U256;
use crate::{PathElement, SimpleSerialize, U256};

use super::*;

Expand Down Expand Up @@ -222,33 +285,51 @@ mod tests {
}

#[test]
fn test_proving_primitives() {
fn test_proving_primitives_fails_with_bad_path() {
let mut data = 8u8;
let (proof, witness) = prove(&mut data, &[]).unwrap();
let result = prove(&mut data, &[PathElement::Length]);
assert!(result.is_err());

let mut data = true;
let result = prove(&mut data, &[234.into()]);
assert!(result.is_err());
}

fn compute_and_verify_proof_for_path<T: SimpleSerialize + Prove>(data: &mut T, path: Path) {
let (proof, witness) = prove(data, path).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
}

#[test]
fn test_prove_primitives() {
let mut data = 8u8;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 0u8;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 234238u64;
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = 0u128;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = u128::MAX;
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = U256::from_str_radix(
"f8c2ed25e9c31399d4149dcaa48c51f394043a6a1297e65780a5979e3d7bb77c",
16,
)
.unwrap();
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok());
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = true;
let (proof, witness) = prove(&mut data, &[]).unwrap();
assert_eq!(witness, data.hash_tree_root().unwrap());
let result = proof.verify(witness);
assert!(result.is_ok())
compute_and_verify_proof_for_path(&mut data, &[]);

let mut data = false;
compute_and_verify_proof_for_path(&mut data, &[]);
}
}
Loading

0 comments on commit 1929aa4

Please sign in to comment.