diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d452475..1c362f24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## 0.1.2 (2023-02-17) + +- Fixed `Rpo256::hash` pad that was panicking on input (#44) +- Added `MerklePath` wrapper to encapsulate Merkle opening verification and root computation (#53) +- Added `NodeIndex` Merkle wrapper to encapsulate Merkle tree traversal and mappings (#54) + ## 0.1.1 (2023-02-06) - Introduced `merge_in_domain` for the RPO hash function, to allow using a specified domain value in the second capacity register when hashing two digests together. @@ -8,6 +14,6 @@ - Initial release on crates.io containing the cryptographic primitives used in Miden VM and the Miden Rollup. - Hash module with the BLAKE3 and Rescue Prime Optimized hash functions. - - BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output. + - BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output. - RPO is implemented with 256-bit output. - Merkle module, with a set of data structures related to Merkle trees, implemented using the RPO hash function. diff --git a/Cargo.toml b/Cargo.toml index 815a42be..bdb47028 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "miden-crypto" -version = "0.1.1" +version = "0.1.2" description="Miden Cryptographic primitives" authors = ["miden contributors"] readme="README.md" diff --git a/README.md b/README.md index f50fbc11..e01e1f2e 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,14 @@ For performance benchmarks of these hash functions and their comparison to other [Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are: * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64. +* `SimpleSmt`: a Sparse Merkle Tree, mapping 63-bit keys to 4-element leaf values. * `MerklePathSet`: a collection of Merkle authentication paths all resolving to the same root. The length of the paths can be at most 64. +The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state. + +## Extra +[Root module](./src/lib.rs) provides a set of constants, types, aliases, and utils required to use the primitives of this library. + ## Crate features This crate can be compiled with the following features: @@ -25,5 +31,21 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/ To compile with `no_std`, disable default features via `--no-default-features` flag. +## Testing + +You can use cargo defaults to test the library: + +```shell +cargo test +``` + +However, some of the functions are heavy and might take a while for the tests to complete. In order to test in release mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified. + +We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation. + +```shell +RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release +``` + ## License This project is [MIT licensed](./LICENSE). diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index 0d2fde3d..67fc1a55 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -94,61 +94,64 @@ impl Hasher for Rpo256 { type Digest = RpoDigest; fn hash(bytes: &[u8]) -> Self::Digest { - // compute the number of elements required to represent the string; we will be processing - // the string in BINARY_CHUNK_SIZE-byte chunks, thus the number of elements will be equal - // to the number of such chunks (including a potential partial chunk at the end). - let num_elements = if bytes.len() % BINARY_CHUNK_SIZE == 0 { - bytes.len() / BINARY_CHUNK_SIZE - } else { - bytes.len() / BINARY_CHUNK_SIZE + 1 - }; - - // initialize state to all zeros, except for the first element of the capacity part, which - // is set to the number of elements to be hashed. this is done so that adding zero elements - // at the end of the list always results in a different hash. + // initialize the state with zeroes let mut state = [ZERO; STATE_WIDTH]; - state[CAPACITY_RANGE.start] = Felt::new(num_elements as u64); - // break the string into BINARY_CHUNK_SIZE-byte chunks, convert each chunk into a field - // element, and absorb the element into the rate portion of the state. we use - // BINARY_CHUNK_SIZE-byte chunks because every BINARY_CHUNK_SIZE-byte chunk is guaranteed - // to map to some field element. - let mut i = 0; + // set the capacity (first element) to a flag on whether or not the input length is evenly + // divided by the rate. this will prevent collisions between padded and non-padded inputs, + // and will rule out the need to perform an extra permutation in case of evenly divided + // inputs. + let is_rate_multiple = bytes.len() % RATE_WIDTH == 0; + if !is_rate_multiple { + state[CAPACITY_RANGE.start] = ONE; + } + + // initialize a buffer to receive the little-endian elements. let mut buf = [0_u8; 8]; - for chunk in bytes.chunks(BINARY_CHUNK_SIZE) { - if i < num_elements - 1 { + + // iterate the chunks of bytes, creating a field element from each chunk and copying it + // into the state. + // + // every time the rate range is filled, a permutation is performed. if the final value of + // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an + // additional permutation must be performed. + let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| { + // the last element of the iteration may or may not be a full chunk. if it's not, then + // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`. + // this will avoid collisions. + if chunk.len() == BINARY_CHUNK_SIZE { buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk); } else { - // if we are dealing with the last chunk, it may be smaller than BINARY_CHUNK_SIZE - // bytes long, so we need to handle it slightly differently. We also append a byte - // with value 1 to the end of the string; this pads the string in such a way that - // adding trailing zeros results in different hash - let chunk_len = chunk.len(); - buf = [0_u8; 8]; - buf[..chunk_len].copy_from_slice(chunk); - buf[chunk_len] = 1; + buf.fill(0); + buf[..chunk.len()].copy_from_slice(chunk); + buf[chunk.len()] = 1; } - // convert the bytes into a field element and absorb it into the rate portion of the - // state; if the rate is filled up, apply the Rescue permutation and start absorbing - // again from zero index. + // set the current rate element to the input. since we take at most 7 bytes, we are + // guaranteed that the inputs data will fit into a single field element. state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf)); - i += 1; - if i % RATE_WIDTH == 0 { + + // proceed filling the range. if it's full, then we apply a permutation and reset the + // counter to the beginning of the range. + if i == RATE_WIDTH - 1 { Self::apply_permutation(&mut state); - i = 0; + 0 + } else { + i + 1 } - } + }); // if we absorbed some elements but didn't apply a permutation to them (would happen when - // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. - // we don't need to apply any extra padding because we injected total number of elements - // in the input list into the capacity portion of the state during initialization. - if i > 0 { + // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we + // don't need to apply any extra padding because the first capacity element containts a + // flag indicating whether the input is evenly divisible by the rate. + if i != 0 { + state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO); + state[RATE_RANGE.start + i] = ONE; Self::apply_permutation(&mut state); } - // return the first 4 elements of the state as hash result + // return the first 4 elements of the rate as hash result. RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) } diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index b10e02b0..d6379bbd 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -2,7 +2,9 @@ use super::{ Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH, ZERO, }; +use crate::utils::collections::{BTreeSet, Vec}; use core::convert::TryInto; +use proptest::prelude::*; use rand_utils::rand_value; #[test] @@ -193,6 +195,43 @@ fn hash_test_vectors() { } } +#[test] +fn sponge_bytes_with_remainder_length_wont_panic() { + // this test targets to assert that no panic will happen with the edge case of having an inputs + // with length that is not divisible by the used binary chunk size. 113 is a non-negligible + // input length that is prime; hence guaranteed to not be divisible by any choice of chunk + // size. + // + // this is a preliminary test to the fuzzy-stress of proptest. + Rpo256::hash(&vec![0; 113]); +} + +#[test] +fn sponge_collision_for_wrapped_field_element() { + let a = Rpo256::hash(&[0; 8]); + let b = Rpo256::hash(&Felt::MODULUS.to_le_bytes()); + assert_ne!(a, b); +} + +#[test] +fn sponge_zeroes_collision() { + let mut zeroes = Vec::with_capacity(255); + let mut set = BTreeSet::new(); + (0..255).for_each(|_| { + let hash = Rpo256::hash(&zeroes); + zeroes.push(0); + // panic if a collision was found + assert!(set.insert(hash)); + }); +} + +proptest! { + #[test] + fn rpo256_wont_panic_with_arbitrary_input(ref vec in any::>()) { + Rpo256::hash(&vec); + } +} + const EXPECTED: [[Felt; 4]; 19] = [ [ Felt::new(1502364727743950833), diff --git a/src/lib.rs b/src/lib.rs index 8cf4e3c1..a68c2bf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,3 +38,32 @@ pub const ZERO: Felt = Felt::ZERO; /// Field element representing ONE in the Miden base filed. pub const ONE: Felt = Felt::ONE; + +// TESTS +// ================================================================================================ + +#[test] +#[should_panic] +fn debug_assert_is_checked() { + // enforce the release checks to always have `RUSTFLAGS="-C debug-assertions". + // + // some upstream tests are performed with `debug_assert`, and we want to assert its correctness + // downstream. + // + // for reference, check + // https://github.com/0xPolygonMiden/miden-vm/issues/433 + debug_assert!(false); +} + +#[test] +#[should_panic] +#[allow(arithmetic_overflow)] +fn overflow_panics_for_test() { + // overflows might be disabled if tests are performed in release mode. these are critical, + // mandatory checks as overflows might be attack vectors. + // + // to enable overflow checks in release mode, ensure `RUSTFLAGS="-C overflow-checks"` + let a = 1_u64; + let b = 64; + assert_ne!(a << b, 0); +} diff --git a/src/merkle/index.rs b/src/merkle/index.rs new file mode 100644 index 00000000..95d71237 --- /dev/null +++ b/src/merkle/index.rs @@ -0,0 +1,126 @@ +use super::{Felt, MerkleError, RpoDigest, StarkField}; + +// NODE INDEX +// ================================================================================================ + +/// A Merkle tree address to an arbitrary node. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct NodeIndex { + depth: u8, + value: u64, +} + +impl NodeIndex { + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Creates a new node index. + pub const fn new(depth: u8, value: u64) -> Self { + Self { depth, value } + } + + /// Creates a node index from a pair of field elements representing the depth and value. + /// + /// # Errors + /// + /// Will error if the `u64` representation of the depth doesn't fit a `u8`. + pub fn from_elements(depth: &Felt, value: &Felt) -> Result { + let depth = depth.as_int(); + let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?; + let value = value.as_int(); + Ok(Self::new(depth, value)) + } + + /// Creates a new node index pointing to the root of the tree. + pub const fn root() -> Self { + Self { depth: 0, value: 0 } + } + + /// Mutates the instance and returns it, replacing the depth. + pub const fn with_depth(mut self, depth: u8) -> Self { + self.depth = depth; + self + } + + /// Computes the value of the sibling of the current node. + pub fn sibling(mut self) -> Self { + self.value ^= 1; + self + } + + // PROVIDERS + // -------------------------------------------------------------------------------------------- + + /// Builds a node to be used as input of a hash function when computing a Merkle path. + /// + /// Will evaluate the parity of the current instance to define the result. + pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] { + if self.is_value_odd() { + [sibling, slf] + } else { + [slf, sibling] + } + } + + /// Returns the scalar representation of the depth/value pair. + /// + /// It is computed as `2^depth + value`. + pub const fn to_scalar_index(&self) -> u64 { + (1 << self.depth as u64) + self.value + } + + /// Returns the depth of the current instance. + pub const fn depth(&self) -> u8 { + self.depth + } + + /// Returns the value of the current depth. + pub const fn value(&self) -> u64 { + self.value + } + + /// Returns true if the current value fits the current depth for a binary tree. + pub const fn is_valid(&self) -> bool { + self.value < (1 << self.depth as u64) + } + + /// Returns true if the current instance points to a right sibling node. + pub const fn is_value_odd(&self) -> bool { + (self.value & 1) == 1 + } + + /// Returns `true` if the depth is `0`. + pub const fn is_root(&self) -> bool { + self.depth == 0 + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Traverse one level towards the root, decrementing the depth by `1`. + pub fn move_up(&mut self) -> &mut Self { + self.depth = self.depth.saturating_sub(1); + self.value >>= 1; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn arbitrary_index_wont_panic_on_move_up( + depth in prop::num::u8::ANY, + value in prop::num::u64::ANY, + count in prop::num::u8::ANY, + ) { + let mut index = NodeIndex::new(depth, value); + for _ in 0..count { + index.move_up(); + } + } + } +} diff --git a/src/merkle/merkle_path_set.rs b/src/merkle/merkle_path_set.rs deleted file mode 100644 index 8210285b..00000000 --- a/src/merkle/merkle_path_set.rs +++ /dev/null @@ -1,344 +0,0 @@ -use super::{BTreeMap, MerkleError, Rpo256, Vec, Word, ZERO}; - -// MERKLE PATH SET -// ================================================================================================ - -/// A set of Merkle paths. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct MerklePathSet { - root: Word, - total_depth: u32, - paths: BTreeMap>, -} - -impl MerklePathSet { - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - - /// Returns an empty MerklePathSet. - pub fn new(depth: u32) -> Result { - let root = [ZERO; 4]; - let paths = BTreeMap::>::new(); - - Ok(Self { - root, - total_depth: depth, - paths, - }) - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters - /// specify the leaf node at which the path starts. - /// - /// # Errors - /// Returns an error if: - /// - The specified index is not valid in the context of this Merkle path set (i.e., the index - /// implies a greater depth than is specified for this set). - /// - The specified path is not consistent with other paths in the set (i.e., resolves to a - /// different root). - pub fn add_path( - &mut self, - index: u64, - value: Word, - path: Vec, - ) -> Result<(), MerkleError> { - let depth = (path.len() + 1) as u32; - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); - } - - // Actual number of node in tree - let pos = 2u64.pow(self.total_depth) + index; - - // Index of the leaf path in map. Paths of neighboring leaves are stored in one key-value pair - let half_pos = pos / 2; - - let mut extended_path = path; - if is_even(pos) { - extended_path.insert(0, value); - } else { - extended_path.insert(1, value); - } - - let root_of_current_path = compute_path_root(&extended_path, depth, index); - if self.root == [ZERO; 4] { - self.root = root_of_current_path; - } else if self.root != root_of_current_path { - return Err(MerkleError::InvalidPath(extended_path)); - } - self.paths.insert(half_pos, extended_path); - - Ok(()) - } - - /// Returns the root to which all paths in this set resolve. - pub fn root(&self) -> Word { - self.root - } - - /// Returns the depth of the Merkle tree implied by the paths stored in this set. - /// - /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. - pub fn depth(&self) -> u32 { - self.total_depth - } - - /// Returns a node at the specified index. - /// - /// # Errors - /// Returns an error if: - /// * The specified index not valid for the depth of structure. - /// * Requested node does not exist in the set. - pub fn get_node(&self, depth: u32, index: u64) -> Result { - if index >= 2u64.pow(self.total_depth) { - return Err(MerkleError::InvalidIndex(self.total_depth, index)); - } - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); - } - - let pos = 2u64.pow(depth) + index; - let index = pos / 2; - - match self.paths.get(&index) { - None => Err(MerkleError::NodeNotInSet(index)), - Some(path) => { - if is_even(pos) { - Ok(path[0]) - } else { - Ok(path[1]) - } - } - } - } - - /// Returns a Merkle path to the node at the specified index. The node itself is - /// not included in the path. - /// - /// # Errors - /// Returns an error if: - /// * The specified index not valid for the depth of structure. - /// * Node of the requested path does not exist in the set. - pub fn get_path(&self, depth: u32, index: u64) -> Result, MerkleError> { - if index >= 2u64.pow(self.total_depth) { - return Err(MerkleError::InvalidIndex(self.total_depth, index)); - } - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); - } - - let pos = 2u64.pow(depth) + index; - let index = pos / 2; - - match self.paths.get(&index) { - None => Err(MerkleError::NodeNotInSet(index)), - Some(path) => { - let mut local_path = path.clone(); - if is_even(pos) { - local_path.remove(0); - Ok(local_path) - } else { - local_path.remove(1); - Ok(local_path) - } - } - } - } - - /// Replaces the leaf at the specified index with the provided value. - /// - /// # Errors - /// Returns an error if: - /// * Requested node does not exist in the set. - pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> { - let depth = self.depth(); - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); - } - let pos = 2u64.pow(depth) + index; - - let path = match self.paths.get_mut(&(pos / 2)) { - None => return Err(MerkleError::NodeNotInSet(index)), - Some(path) => path, - }; - - // Fill old_hashes vector ----------------------------------------------------------------- - let (old_hashes, _) = compute_path_trace(path, depth, index); - - // Fill new_hashes vector ----------------------------------------------------------------- - if is_even(pos) { - path[0] = value; - } else { - path[1] = value; - } - - let (new_hashes, new_root) = compute_path_trace(path, depth, index); - self.root = new_root; - - // update paths --------------------------------------------------------------------------- - for path in self.paths.values_mut() { - for i in (0..old_hashes.len()).rev() { - if path[i + 2] == old_hashes[i] { - path[i + 2] = new_hashes[i]; - break; - } - } - } - - Ok(()) - } -} - -// HELPER FUNCTIONS -// -------------------------------------------------------------------------------------------- - -fn is_even(pos: u64) -> bool { - pos & 1 == 0 -} - -/// Calculates the hash of the parent node by two sibling ones -/// - node — current node -/// - node_pos — position of the current node -/// - sibling — neighboring vertex in the tree -fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word { - if is_even(node_pos) { - Rpo256::merge(&[node.into(), sibling.into()]).into() - } else { - Rpo256::merge(&[sibling.into(), node.into()]).into() - } -} - -/// Returns vector of hashes from current to the root -fn compute_path_trace(path: &[Word], depth: u32, index: u64) -> (Vec, Word) { - let mut pos = 2u64.pow(depth) + index; - - let mut computed_hashes = Vec::::new(); - - let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); - - if path.len() != 2 { - for path_hash in path.iter().skip(2) { - computed_hashes.push(comp_hash); - pos /= 2; - comp_hash = calculate_parent_hash(comp_hash, pos, *path_hash); - } - } - - (computed_hashes, comp_hash) -} - -/// Returns hash of the root -fn compute_path_root(path: &[Word], depth: u32, index: u64) -> Word { - let mut pos = 2u64.pow(depth) + index; - - // hash that is obtained after calculating the current hash and path hash - let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); - - for path_hash in path.iter().skip(2) { - pos /= 2; - comp_hash = calculate_parent_hash(comp_hash, pos, *path_hash); - } - - comp_hash -} - -// TESTS -// ================================================================================================ - -#[cfg(test)] -mod tests { - use super::calculate_parent_hash; - use crate::merkle::int_to_node; - - #[test] - fn get_root() { - let leaf0 = int_to_node(0); - let leaf1 = int_to_node(1); - let leaf2 = int_to_node(2); - let leaf3 = int_to_node(3); - - let parent0 = calculate_parent_hash(leaf0, 0, leaf1); - let parent1 = calculate_parent_hash(leaf2, 2, leaf3); - - let root_exp = calculate_parent_hash(parent0, 0, parent1); - - let mut set = super::MerklePathSet::new(3).unwrap(); - - set.add_path(0, leaf0, vec![leaf1, parent1]).unwrap(); - - assert_eq!(set.root(), root_exp); - } - - #[test] - fn add_and_get_path() { - let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_node(6); - let index = 6u64; - let depth = 4u32; - let mut set = super::MerklePathSet::new(depth).unwrap(); - - set.add_path(index, hash_6, path_6.clone()).unwrap(); - let stored_path_6 = set.get_path(depth, index).unwrap(); - - assert_eq!(path_6, stored_path_6); - assert!(set.get_path(depth, 15u64).is_err()) - } - - #[test] - fn get_node() { - let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_node(6); - let index = 6u64; - let depth = 4u32; - let mut set = super::MerklePathSet::new(depth).unwrap(); - - set.add_path(index, hash_6, path_6).unwrap(); - - assert_eq!(int_to_node(6u64), set.get_node(depth, index).unwrap()); - assert!(set.get_node(depth, 15u64).is_err()); - } - - #[test] - fn update_leaf() { - let hash_4 = int_to_node(4); - let hash_5 = int_to_node(5); - let hash_6 = int_to_node(6); - let hash_7 = int_to_node(7); - let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5); - let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7); - - let hash_0123 = int_to_node(123); - - let path_6 = vec![hash_7, hash_45, hash_0123]; - let path_5 = vec![hash_4, hash_67, hash_0123]; - let path_4 = vec![hash_5, hash_67, hash_0123]; - - let index_6 = 6u64; - let index_5 = 5u64; - let index_4 = 4u64; - let depth = 4u32; - let mut set = super::MerklePathSet::new(depth).unwrap(); - - set.add_path(index_6, hash_6, path_6).unwrap(); - set.add_path(index_5, hash_5, path_5).unwrap(); - set.add_path(index_4, hash_4, path_4).unwrap(); - - let new_hash_6 = int_to_node(100); - let new_hash_5 = int_to_node(55); - - set.update_leaf(index_6, new_hash_6).unwrap(); - let new_path_4 = set.get_path(depth, index_4).unwrap(); - let new_hash_67 = calculate_parent_hash(new_hash_6, 14u64, hash_7); - assert_eq!(new_hash_67, new_path_4[1]); - - set.update_leaf(index_5, new_hash_5).unwrap(); - let new_path_4 = set.get_path(depth, index_4).unwrap(); - let new_path_6 = set.get_path(depth, index_6).unwrap(); - let new_hash_45 = calculate_parent_hash(new_hash_5, 13u64, hash_4); - assert_eq!(new_hash_45, new_path_6[1]); - assert_eq!(new_hash_5, new_path_4[0]); - } -} diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index 5eff1c35..e9c53ea1 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -1,4 +1,4 @@ -use super::{Felt, MerkleError, Rpo256, RpoDigest, Vec, Word}; +use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word}; use crate::{utils::uninit_vector, FieldElement}; use core::slice; use winter_math::log2; @@ -22,7 +22,7 @@ impl MerkleTree { pub fn new(leaves: Vec) -> Result { let n = leaves.len(); if n <= 1 { - return Err(MerkleError::DepthTooSmall(n as u32)); + return Err(MerkleError::DepthTooSmall(n as u8)); } else if !n.is_power_of_two() { return Err(MerkleError::NumLeavesNotPowerOfTwo(n)); } @@ -35,12 +35,14 @@ impl MerkleTree { nodes[n..].copy_from_slice(&leaves); // re-interpret nodes as an array of two nodes fused together - let two_nodes = - unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) }; + // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e. + // `self`). + let ptr = nodes.as_ptr() as *const [RpoDigest; 2]; + let pairs = unsafe { slice::from_raw_parts(ptr, n) }; // calculate all internal tree nodes for i in (1..n).rev() { - nodes[i] = Rpo256::merge(&two_nodes[i]).into(); + nodes[i] = Rpo256::merge(&pairs[i]).into(); } Ok(Self { nodes }) @@ -57,78 +59,93 @@ impl MerkleTree { /// Returns the depth of this Merkle tree. /// /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. - pub fn depth(&self) -> u32 { - log2(self.nodes.len() / 2) + pub fn depth(&self) -> u8 { + log2(self.nodes.len() / 2) as u8 } - /// Returns a node at the specified depth and index. + /// Returns a node at the specified depth and index value. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. - /// * The specified index not valid for the specified depth. - pub fn get_node(&self, depth: u32, index: u64) -> Result { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); - } - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + /// * The specified index is not valid for the specified depth. + pub fn get_node(&self, index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } else if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let pos = 2_usize.pow(depth) + (index as usize); + let pos = index.to_scalar_index() as usize; Ok(self.nodes[pos]) } - /// Returns a Merkle path to the node at the specified depth and index. The node itself is - /// not included in the path. + /// Returns a Merkle path to the node at the specified depth and index value. The node itself + /// is not included in the path. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. - /// * The specified index not valid for the specified depth. - pub fn get_path(&self, depth: u32, index: u64) -> Result, MerkleError> { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); - } - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + /// * The specified value is not valid for the specified depth. + pub fn get_path(&self, mut index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } else if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let mut path = Vec::with_capacity(depth as usize); - let mut pos = 2_usize.pow(depth) + (index as usize); - - while pos > 1 { - path.push(self.nodes[pos ^ 1]); - pos >>= 1; + // TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so + // we always use inlined `for` instead of `while`? the reason to use `for` is because its + // easier for the compiler to vectorize. + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let sibling = index.sibling().to_scalar_index() as usize; + path.push(self.nodes[sibling]); + index.move_up(); } - Ok(path) + Ok(path.into()) } /// Replaces the leaf at the specified index with the provided value. /// /// # Errors - /// Returns an error if the specified index is not a valid leaf index for this tree. - pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> { + /// Returns an error if the specified index value is not a valid leaf value for this tree. + pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> { let depth = self.depth(); - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + let mut index = NodeIndex::new(depth, index_value); + if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let mut index = 2usize.pow(depth) + index as usize; - self.nodes[index] = value; - + // we don't need to copy the pairs into a new address as we are logically guaranteed to not + // overlap write instructions. however, it's important to bind the lifetime of pairs to + // `self.nodes` so the compiler will never move one without moving the other. + debug_assert_eq!(self.nodes.len() & 1, 0); let n = self.nodes.len() / 2; - let two_nodes = - unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) }; - for _ in 0..depth { - index /= 2; - self.nodes[index] = Rpo256::merge(&two_nodes[index]).into(); + // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of + // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that + // `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is + // logically guaranteed to not overlap write positions as the write index is always half + // the index from which we read the digest input. + let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2]; + let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) }; + + // update the current node + let pos = index.to_scalar_index() as usize; + self.nodes[pos] = value; + + // traverse to the root, updating each node with the merged values of its parents + for _ in 0..index.depth() { + index.move_up(); + let pos = index.to_scalar_index() as usize; + let value = Rpo256::merge(&pairs[pos]).into(); + self.nodes[pos] = value; } Ok(()) @@ -140,10 +157,10 @@ impl MerkleTree { #[cfg(test)] mod tests { - use super::{ - super::{int_to_node, Rpo256}, - Word, - }; + use super::*; + use crate::merkle::int_to_node; + use core::mem::size_of; + use proptest::prelude::*; const LEAVES4: [Word; 4] = [ int_to_node(1), @@ -187,16 +204,16 @@ mod tests { let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap(); // check depth 2 - assert_eq!(LEAVES4[0], tree.get_node(2, 0).unwrap()); - assert_eq!(LEAVES4[1], tree.get_node(2, 1).unwrap()); - assert_eq!(LEAVES4[2], tree.get_node(2, 2).unwrap()); - assert_eq!(LEAVES4[3], tree.get_node(2, 3).unwrap()); + assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap()); + assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap()); + assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap()); + assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap()); // check depth 1 let (_, node2, node3) = compute_internal_nodes(); - assert_eq!(node2, tree.get_node(1, 0).unwrap()); - assert_eq!(node3, tree.get_node(1, 1).unwrap()); + assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -206,14 +223,26 @@ mod tests { let (_, node2, node3) = compute_internal_nodes(); // check depth 2 - assert_eq!(vec![LEAVES4[1], node3], tree.get_path(2, 0).unwrap()); - assert_eq!(vec![LEAVES4[0], node3], tree.get_path(2, 1).unwrap()); - assert_eq!(vec![LEAVES4[3], node2], tree.get_path(2, 2).unwrap()); - assert_eq!(vec![LEAVES4[2], node2], tree.get_path(2, 3).unwrap()); + assert_eq!( + vec![LEAVES4[1], node3], + *tree.get_path(NodeIndex::new(2, 0)).unwrap() + ); + assert_eq!( + vec![LEAVES4[0], node3], + *tree.get_path(NodeIndex::new(2, 1)).unwrap() + ); + assert_eq!( + vec![LEAVES4[3], node2], + *tree.get_path(NodeIndex::new(2, 2)).unwrap() + ); + assert_eq!( + vec![LEAVES4[2], node2], + *tree.get_path(NodeIndex::new(2, 3)).unwrap() + ); // check depth 1 - assert_eq!(vec![node3], tree.get_path(1, 0).unwrap()); - assert_eq!(vec![node2], tree.get_path(1, 1).unwrap()); + assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -221,25 +250,53 @@ mod tests { let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap(); // update one leaf - let index = 3; + let value = 3; let new_node = int_to_node(9); let mut expected_leaves = LEAVES8.to_vec(); - expected_leaves[index as usize] = new_node; + expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); - tree.update_leaf(index, new_node).unwrap(); + tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); // update another leaf - let index = 6; + let value = 6; let new_node = int_to_node(10); - expected_leaves[index as usize] = new_node; + expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); - tree.update_leaf(index, new_node).unwrap(); + tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); } + proptest! { + #[test] + fn arbitrary_word_can_be_represented_as_digest( + a in prop::num::u64::ANY, + b in prop::num::u64::ANY, + c in prop::num::u64::ANY, + d in prop::num::u64::ANY, + ) { + // this test will assert the memory equivalence between word and digest. + // it is used to safeguard the `[MerkleTee::update_leaf]` implementation + // that assumes this equivalence. + + // build a word and copy it to another address as digest + let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]; + let digest = RpoDigest::from(word); + + // assert the addresses are different + let word_ptr = (&word).as_ptr() as *const u8; + let digest_ptr = (&digest).as_ptr() as *const u8; + assert_ne!(word_ptr, digest_ptr); + + // compare the bytes representation + let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::()) }; + let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::()) }; + assert_eq!(word_bytes, digest_bytes); + } + } + // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 87cd80f5..0b827522 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,15 +1,21 @@ use super::{ hash::rpo::{Rpo256, RpoDigest}, - utils::collections::{BTreeMap, Vec}, - Felt, Word, ZERO, + utils::collections::{vec, BTreeMap, Vec}, + Felt, StarkField, Word, ZERO, }; use core::fmt; +mod index; +pub use index::NodeIndex; + mod merkle_tree; pub use merkle_tree::MerkleTree; -mod merkle_path_set; -pub use merkle_path_set::MerklePathSet; +mod path; +pub use path::MerklePath; + +mod path_set; +pub use path_set::MerklePathSet; mod simple_smt; pub use simple_smt::SimpleSmt; @@ -19,12 +25,12 @@ pub use simple_smt::SimpleSmt; #[derive(Clone, Debug)] pub enum MerkleError { - DepthTooSmall(u32), - DepthTooBig(u32), + DepthTooSmall(u8), + DepthTooBig(u64), NumLeavesNotPowerOfTwo(usize), - InvalidIndex(u32, u64), - InvalidDepth(u32, u32), - InvalidPath(Vec), + InvalidIndex(NodeIndex), + InvalidDepth { expected: u8, provided: u8 }, + InvalidPath(MerklePath), InvalidEntriesCount(usize, usize), NodeNotInSet(u64), } @@ -38,11 +44,11 @@ impl fmt::Display for MerkleError { NumLeavesNotPowerOfTwo(leaves) => { write!(f, "the leaves count {leaves} is not a power of 2") } - InvalidIndex(depth, index) => write!( + InvalidIndex(index) => write!( f, - "the leaf index {index} is not valid for the depth {depth}" + "the index value {} is not valid for the depth {}", index.value(), index.depth() ), - InvalidDepth(expected, provided) => write!( + InvalidDepth { expected, provided } => write!( f, "the provided depth {provided} is not valid for {expected}" ), diff --git a/src/merkle/path.rs b/src/merkle/path.rs new file mode 100644 index 00000000..d7edd5dc --- /dev/null +++ b/src/merkle/path.rs @@ -0,0 +1,84 @@ +use super::{vec, NodeIndex, Rpo256, Vec, Word}; +use core::ops::{Deref, DerefMut}; + +// MERKLE PATH +// ================================================================================================ + +/// A merkle path container, composed of a sequence of nodes of a Merkle tree. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct MerklePath { + nodes: Vec, +} + +impl MerklePath { + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Creates a new Merkle path from a list of nodes. + pub fn new(nodes: Vec) -> Self { + Self { nodes } + } + + // PROVIDERS + // -------------------------------------------------------------------------------------------- + + /// Computes the merkle root for this opening. + pub fn compute_root(&self, index_value: u64, node: Word) -> Word { + let mut index = NodeIndex::new(self.depth(), index_value); + self.nodes.iter().copied().fold(node, |node, sibling| { + // compute the node and move to the next iteration. + let input = index.build_node(node.into(), sibling.into()); + index.move_up(); + Rpo256::merge(&input).into() + }) + } + + /// Returns the depth in which this Merkle path proof is valid. + pub fn depth(&self) -> u8 { + self.nodes.len() as u8 + } + + /// Verifies the Merkle opening proof towards the provided root. + /// + /// Returns `true` if `node` exists at `index` in a Merkle tree with `root`. + pub fn verify(&self, index: u64, node: Word, root: &Word) -> bool { + root == &self.compute_root(index, node) + } +} + +impl From> for MerklePath { + fn from(path: Vec) -> Self { + Self::new(path) + } +} + +impl Deref for MerklePath { + // we use `Vec` here instead of slice so we can call vector mutation methods directly from the + // merkle path (example: `Vec::remove`). + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.nodes + } +} + +impl DerefMut for MerklePath { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.nodes + } +} + +impl FromIterator for MerklePath { + fn from_iter>(iter: T) -> Self { + Self::new(iter.into_iter().collect()) + } +} + +impl IntoIterator for MerklePath { + type Item = Word; + type IntoIter = vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.nodes.into_iter() + } +} diff --git a/src/merkle/path_set.rs b/src/merkle/path_set.rs new file mode 100644 index 00000000..6acc12c0 --- /dev/null +++ b/src/merkle/path_set.rs @@ -0,0 +1,333 @@ +use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, Vec, Word, ZERO}; + +// MERKLE PATH SET +// ================================================================================================ + +/// A set of Merkle paths. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MerklePathSet { + root: Word, + total_depth: u8, + paths: BTreeMap, +} + +impl MerklePathSet { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + + /// Returns an empty MerklePathSet. + pub fn new(depth: u8) -> Result { + let root = [ZERO; 4]; + let paths = BTreeMap::new(); + + Ok(Self { + root, + total_depth: depth, + paths, + }) + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the root to which all paths in this set resolve. + pub const fn root(&self) -> Word { + self.root + } + + /// Returns the depth of the Merkle tree implied by the paths stored in this set. + /// + /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. + pub const fn depth(&self) -> u8 { + self.total_depth + } + + /// Returns a node at the specified index. + /// + /// # Errors + /// Returns an error if: + /// * The specified index is not valid for the depth of structure. + /// * Requested node does not exist in the set. + pub fn get_node(&self, index: NodeIndex) -> Result { + if !index.with_depth(self.total_depth).is_valid() { + return Err(MerkleError::InvalidIndex( + index.with_depth(self.total_depth), + )); + } + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); + } + + let index_value = index.to_scalar_index(); + let parity = index_value & 1; + let index_value = index_value / 2; + self.paths + .get(&index_value) + .ok_or(MerkleError::NodeNotInSet(index_value)) + .map(|path| path[parity as usize]) + } + + /// Returns a Merkle path to the node at the specified index. The node itself is + /// not included in the path. + /// + /// # Errors + /// Returns an error if: + /// * The specified index is not valid for the depth of structure. + /// * Node of the requested path does not exist in the set. + pub fn get_path(&self, index: NodeIndex) -> Result { + if !index.with_depth(self.total_depth).is_valid() { + return Err(MerkleError::InvalidIndex(index)); + } + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); + } + + let index_value = index.to_scalar_index(); + let index = index_value / 2; + let parity = index_value & 1; + let mut path = self + .paths + .get(&index) + .cloned() + .ok_or(MerkleError::NodeNotInSet(index))?; + path.remove(parity as usize); + Ok(path) + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters + /// specify the leaf node at which the path starts. + /// + /// # Errors + /// Returns an error if: + /// - The specified index is is not valid in the context of this Merkle path set (i.e., the + /// index implies a greater depth than is specified for this set). + /// - The specified path is not consistent with other paths in the set (i.e., resolves to a + /// different root). + pub fn add_path( + &mut self, + index_value: u64, + value: Word, + mut path: MerklePath, + ) -> Result<(), MerkleError> { + let depth = (path.len() + 1) as u8; + let mut index = NodeIndex::new(depth, index_value); + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); + } + + // update the current path + let index_value = index.to_scalar_index(); + let upper_index_value = index_value / 2; + let parity = index_value & 1; + path.insert(parity as usize, value); + + // traverse to the root, updating the nodes + let root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + let root = path.iter().skip(2).copied().fold(root, |root, hash| { + index.move_up(); + Rpo256::merge(&index.build_node(root.into(), hash.into())).into() + }); + + // if the path set is empty (the root is all ZEROs), set the root to the root of the added + // path; otherwise, the root of the added path must be identical to the current root + if self.root == [ZERO; 4] { + self.root = root; + } else if self.root != root { + return Err(MerkleError::InvalidPath(path)); + } + + // finish updating the path + self.paths.insert(upper_index_value, path); + Ok(()) + } + + /// Replaces the leaf at the specified index with the provided value. + /// + /// # Errors + /// Returns an error if: + /// * Requested node does not exist in the set. + pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> { + let depth = self.depth(); + let mut index = NodeIndex::new(depth, base_index_value); + if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); + } + + let path = match self + .paths + .get_mut(&index.clone().move_up().to_scalar_index()) + { + Some(path) => path, + None => return Err(MerkleError::NodeNotInSet(base_index_value)), + }; + + // Fill old_hashes vector ----------------------------------------------------------------- + let mut current_index = index; + let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2)); + let mut root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + for hash in path.iter().skip(2).copied() { + old_hashes.push(root); + current_index.move_up(); + let input = current_index.build_node(hash.into(), root.into()); + root = Rpo256::merge(&input).into(); + } + + // Fill new_hashes vector ----------------------------------------------------------------- + path[index.is_value_odd() as usize] = value; + + let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2)); + let mut new_root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + for path_hash in path.iter().skip(2).copied() { + new_hashes.push(new_root); + index.move_up(); + let input = current_index.build_node(path_hash.into(), new_root.into()); + new_root = Rpo256::merge(&input).into(); + } + + self.root = new_root; + + // update paths --------------------------------------------------------------------------- + for path in self.paths.values_mut() { + for i in (0..old_hashes.len()).rev() { + if path[i + 2] == old_hashes[i] { + path[i + 2] = new_hashes[i]; + break; + } + } + } + + Ok(()) + } +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use crate::merkle::int_to_node; + + #[test] + fn get_root() { + let leaf0 = int_to_node(0); + let leaf1 = int_to_node(1); + let leaf2 = int_to_node(2); + let leaf3 = int_to_node(3); + + let parent0 = calculate_parent_hash(leaf0, 0, leaf1); + let parent1 = calculate_parent_hash(leaf2, 2, leaf3); + + let root_exp = calculate_parent_hash(parent0, 0, parent1); + + let mut set = super::MerklePathSet::new(3).unwrap(); + + set.add_path(0, leaf0, vec![leaf1, parent1].into()).unwrap(); + + assert_eq!(set.root(), root_exp); + } + + #[test] + fn add_and_get_path() { + let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; + let hash_6 = int_to_node(6); + let index = 6_u64; + let depth = 4_u8; + let mut set = super::MerklePathSet::new(depth).unwrap(); + + set.add_path(index, hash_6, path_6.clone().into()).unwrap(); + let stored_path_6 = set.get_path(NodeIndex::new(depth, index)).unwrap(); + + assert_eq!(path_6, *stored_path_6); + assert!(set.get_path(NodeIndex::new(depth, 15_u64)).is_err()) + } + + #[test] + fn get_node() { + let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; + let hash_6 = int_to_node(6); + let index = 6_u64; + let depth = 4_u8; + let mut set = MerklePathSet::new(depth).unwrap(); + + set.add_path(index, hash_6, path_6.into()).unwrap(); + + assert_eq!( + int_to_node(6u64), + set.get_node(NodeIndex::new(depth, index)).unwrap() + ); + assert!(set.get_node(NodeIndex::new(depth, 15_u64)).is_err()); + } + + #[test] + fn update_leaf() { + let hash_4 = int_to_node(4); + let hash_5 = int_to_node(5); + let hash_6 = int_to_node(6); + let hash_7 = int_to_node(7); + let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5); + let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7); + + let hash_0123 = int_to_node(123); + + let path_6 = vec![hash_7, hash_45, hash_0123]; + let path_5 = vec![hash_4, hash_67, hash_0123]; + let path_4 = vec![hash_5, hash_67, hash_0123]; + + let index_6 = 6_u64; + let index_5 = 5_u64; + let index_4 = 4_u64; + let depth = 4_u8; + let mut set = MerklePathSet::new(depth).unwrap(); + + set.add_path(index_6, hash_6, path_6.into()).unwrap(); + set.add_path(index_5, hash_5, path_5.into()).unwrap(); + set.add_path(index_4, hash_4, path_4.into()).unwrap(); + + let new_hash_6 = int_to_node(100); + let new_hash_5 = int_to_node(55); + + set.update_leaf(index_6, new_hash_6).unwrap(); + let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap(); + let new_hash_67 = calculate_parent_hash(new_hash_6, 14_u64, hash_7); + assert_eq!(new_hash_67, new_path_4[1]); + + set.update_leaf(index_5, new_hash_5).unwrap(); + let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap(); + let new_path_6 = set.get_path(NodeIndex::new(depth, index_6)).unwrap(); + let new_hash_45 = calculate_parent_hash(new_hash_5, 13_u64, hash_4); + assert_eq!(new_hash_45, new_path_6[1]); + assert_eq!(new_hash_5, new_path_4[0]); + } + + // HELPER FUNCTIONS + // -------------------------------------------------------------------------------------------- + + const fn is_even(pos: u64) -> bool { + pos & 1 == 0 + } + + /// Calculates the hash of the parent node by two sibling ones + /// - node — current node + /// - node_pos — position of the current node + /// - sibling — neighboring vertex in the tree + fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word { + if is_even(node_pos) { + Rpo256::merge(&[node.into(), sibling.into()]).into() + } else { + Rpo256::merge(&[sibling.into(), node.into()]).into() + } + } +} diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 821b1a6f..186ac259 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,4 +1,4 @@ -use super::{BTreeMap, MerkleError, Rpo256, RpoDigest, Vec, Word}; +use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word}; #[cfg(test)] mod tests; @@ -12,7 +12,7 @@ mod tests; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SimpleSmt { root: Word, - depth: u32, + depth: u8, store: Store, } @@ -21,10 +21,10 @@ impl SimpleSmt { // -------------------------------------------------------------------------------------------- /// Minimum supported depth. - pub const MIN_DEPTH: u32 = 1; + pub const MIN_DEPTH: u8 = 1; /// Maximum supported depth. - pub const MAX_DEPTH: u32 = 63; + pub const MAX_DEPTH: u8 = 63; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -37,7 +37,7 @@ impl SimpleSmt { /// /// The function will fail if the provided entries count exceed the maximum tree capacity, that /// is `2^{depth}`. - pub fn new(entries: R, depth: u32) -> Result + pub fn new(entries: R, depth: u8) -> Result where R: IntoIterator, I: Iterator + ExactSizeIterator, @@ -49,7 +49,7 @@ impl SimpleSmt { if depth < Self::MIN_DEPTH { return Err(MerkleError::DepthTooSmall(depth)); } else if Self::MAX_DEPTH < depth { - return Err(MerkleError::DepthTooBig(depth)); + return Err(MerkleError::DepthTooBig(depth as u64)); } else if entries.len() > max { return Err(MerkleError::InvalidEntriesCount(max, entries.len())); } @@ -67,7 +67,7 @@ impl SimpleSmt { } /// Returns the depth of this Merkle tree. - pub const fn depth(&self) -> u32 { + pub const fn depth(&self) -> u8 { self.depth } @@ -82,15 +82,15 @@ impl SimpleSmt { /// Returns an error if: /// * The specified depth is greater than the depth of the tree. /// * The specified key does not exist - pub fn get_node(&self, depth: u32, key: u64) -> Result { - if depth == 0 { - Err(MerkleError::DepthTooSmall(depth)) - } else if depth > self.depth() { - Err(MerkleError::DepthTooBig(depth)) - } else if depth == self.depth() { - self.store.get_leaf_node(key) + pub fn get_node(&self, index: &NodeIndex) -> Result { + if index.is_root() { + Err(MerkleError::DepthTooSmall(index.depth())) + } else if index.depth() > self.depth() { + Err(MerkleError::DepthTooBig(index.depth() as u64)) + } else if index.depth() == self.depth() { + self.store.get_leaf_node(index.value()) } else { - let branch_node = self.store.get_branch_node(key, depth)?; + let branch_node = self.store.get_branch_node(index)?; Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into()) } } @@ -102,29 +102,25 @@ impl SimpleSmt { /// Returns an error if: /// * The specified key does not exist as a branch or leaf node /// * The specified depth is greater than the depth of the tree. - pub fn get_path(&self, depth: u32, key: u64) -> Result, MerkleError> { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); - } else if depth == self.depth() && !self.store.check_leaf_node_exists(key) { - return Err(MerkleError::InvalidIndex(self.depth(), key)); + pub fn get_path(&self, mut index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } else if index.depth() == self.depth() && !self.store.check_leaf_node_exists(index.value()) + { + return Err(MerkleError::InvalidIndex(index.with_depth(self.depth()))); } - let mut path = Vec::with_capacity(depth as usize); - let mut curr_key = key; - for n in (0..depth).rev() { - let parent_key = curr_key >> 1; - let parent_node = self.store.get_branch_node(parent_key, n)?; - let sibling_node = if curr_key & 1 == 1 { - parent_node.left - } else { - parent_node.right - }; - path.push(sibling_node.into()); - curr_key >>= 1; + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let BranchNode { left, right } = self.store.get_branch_node(&index)?; + let value = if is_right { left } else { right }; + path.push(*value); } - Ok(path) + Ok(path.into()) } /// Return a Merkle path from the leaf at the specified key to the root. The leaf itself is not @@ -133,8 +129,8 @@ impl SimpleSmt { /// # Errors /// Returns an error if: /// * The specified key does not exist as a leaf node. - pub fn get_leaf_path(&self, key: u64) -> Result, MerkleError> { - self.get_path(self.depth(), key) + pub fn get_leaf_path(&self, key: u64) -> Result { + self.get_path(NodeIndex::new(self.depth(), key)) } /// Replaces the leaf located at the specified key, and recomputes hashes by walking up the tree @@ -143,7 +139,7 @@ impl SimpleSmt { /// Returns an error if the specified key is not a valid leaf index for this tree. pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { if !self.store.check_leaf_node_exists(key) { - return Err(MerkleError::InvalidIndex(self.depth(), key)); + return Err(MerkleError::InvalidIndex(NodeIndex::new(self.depth(), key))); } self.insert_leaf(key, value)?; @@ -154,27 +150,25 @@ impl SimpleSmt { pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { self.store.insert_leaf_node(key, value); - let depth = self.depth(); - let mut curr_key = key; - let mut curr_node: RpoDigest = value.into(); - for n in (0..depth).rev() { - let parent_key = curr_key >> 1; - let parent_node = self + // TODO consider using a map `index |-> word` instead of `index |-> (word, word)` + let mut index = NodeIndex::new(self.depth(), key); + let mut value = RpoDigest::from(value); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let BranchNode { left, right } = self .store - .get_branch_node(parent_key, n) - .unwrap_or_else(|_| self.store.get_empty_node((n + 1) as usize)); - let (left, right) = if curr_key & 1 == 1 { - (parent_node.left, curr_node) + .get_branch_node(&index) + .unwrap_or_else(|_| self.store.get_empty_node(index.depth() as usize + 1)); + let (left, right) = if is_right { + (left, value) } else { - (curr_node, parent_node.right) + (value, right) }; - - self.store.insert_branch_node(parent_key, n, left, right); - curr_key = parent_key; - curr_node = Rpo256::merge(&[left, right]); + self.store.insert_branch_node(index, left, right); + value = Rpo256::merge(&[left, right]); } - self.root = curr_node.into(); - + self.root = value.into(); Ok(()) } } @@ -188,10 +182,10 @@ impl SimpleSmt { /// with the root hash of an empty tree, and ending with the zero value of a leaf node. #[derive(Debug, Clone, PartialEq, Eq)] struct Store { - branches: BTreeMap<(u64, u32), BranchNode>, + branches: BTreeMap, leaves: BTreeMap, empty_hashes: Vec, - depth: u32, + depth: u8, } #[derive(Debug, Default, Clone, PartialEq, Eq)] @@ -201,7 +195,7 @@ struct BranchNode { } impl Store { - fn new(depth: u32) -> (Self, Word) { + fn new(depth: u8) -> (Self, Word) { let branches = BTreeMap::new(); let leaves = BTreeMap::new(); @@ -244,23 +238,23 @@ impl Store { self.leaves .get(&key) .cloned() - .ok_or(MerkleError::InvalidIndex(self.depth, key)) + .ok_or(MerkleError::InvalidIndex(NodeIndex::new(self.depth, key))) } fn insert_leaf_node(&mut self, key: u64, node: Word) { self.leaves.insert(key, node); } - fn get_branch_node(&self, key: u64, depth: u32) -> Result { + fn get_branch_node(&self, index: &NodeIndex) -> Result { self.branches - .get(&(key, depth)) + .get(index) .cloned() - .ok_or(MerkleError::InvalidIndex(depth, key)) + .ok_or(MerkleError::InvalidIndex(*index)) } - fn insert_branch_node(&mut self, key: u64, depth: u32, left: RpoDigest, right: RpoDigest) { - let node = BranchNode { left, right }; - self.branches.insert((key, depth), node); + fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) { + let branch = BranchNode { left, right }; + self.branches.insert(index, branch); } fn leaves_count(&self) -> usize { diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 7042d1b6..2096fd1c 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -1,6 +1,6 @@ use super::{ super::{MerkleTree, RpoDigest, SimpleSmt}, - Rpo256, Vec, Word, + NodeIndex, Rpo256, Vec, Word, }; use crate::{Felt, FieldElement}; use core::iter; @@ -62,7 +62,10 @@ fn build_sparse_tree() { .expect("Failed to insert leaf"); let mt2 = MerkleTree::new(values.clone()).unwrap(); assert_eq!(mt2.root(), smt.root()); - assert_eq!(mt2.get_path(3, 6).unwrap(), smt.get_path(3, 6).unwrap()); + assert_eq!( + mt2.get_path(NodeIndex::new(3, 6)).unwrap(), + smt.get_path(NodeIndex::new(3, 6)).unwrap() + ); // insert second value at distinct leaf branch let key = 2; @@ -72,7 +75,10 @@ fn build_sparse_tree() { .expect("Failed to insert leaf"); let mt3 = MerkleTree::new(values).unwrap(); assert_eq!(mt3.root(), smt.root()); - assert_eq!(mt3.get_path(3, 2).unwrap(), smt.get_path(3, 2).unwrap()); + assert_eq!( + mt3.get_path(NodeIndex::new(3, 2)).unwrap(), + smt.get_path(NodeIndex::new(3, 2)).unwrap() + ); } #[test] @@ -81,8 +87,8 @@ fn build_full_tree() { let (root, node2, node3) = compute_internal_nodes(); assert_eq!(root, tree.root()); - assert_eq!(node2, tree.get_node(1, 0).unwrap()); - assert_eq!(node3, tree.get_node(1, 1).unwrap()); + assert_eq!(node2, tree.get_node(&NodeIndex::new(1, 0)).unwrap()); + assert_eq!(node3, tree.get_node(&NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -90,10 +96,10 @@ fn get_values() { let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap(); // check depth 2 - assert_eq!(VALUES4[0], tree.get_node(2, 0).unwrap()); - assert_eq!(VALUES4[1], tree.get_node(2, 1).unwrap()); - assert_eq!(VALUES4[2], tree.get_node(2, 2).unwrap()); - assert_eq!(VALUES4[3], tree.get_node(2, 3).unwrap()); + assert_eq!(VALUES4[0], tree.get_node(&NodeIndex::new(2, 0)).unwrap()); + assert_eq!(VALUES4[1], tree.get_node(&NodeIndex::new(2, 1)).unwrap()); + assert_eq!(VALUES4[2], tree.get_node(&NodeIndex::new(2, 2)).unwrap()); + assert_eq!(VALUES4[3], tree.get_node(&NodeIndex::new(2, 3)).unwrap()); } #[test] @@ -103,14 +109,26 @@ fn get_path() { let (_, node2, node3) = compute_internal_nodes(); // check depth 2 - assert_eq!(vec![VALUES4[1], node3], tree.get_path(2, 0).unwrap()); - assert_eq!(vec![VALUES4[0], node3], tree.get_path(2, 1).unwrap()); - assert_eq!(vec![VALUES4[3], node2], tree.get_path(2, 2).unwrap()); - assert_eq!(vec![VALUES4[2], node2], tree.get_path(2, 3).unwrap()); + assert_eq!( + vec![VALUES4[1], node3], + *tree.get_path(NodeIndex::new(2, 0)).unwrap() + ); + assert_eq!( + vec![VALUES4[0], node3], + *tree.get_path(NodeIndex::new(2, 1)).unwrap() + ); + assert_eq!( + vec![VALUES4[3], node2], + *tree.get_path(NodeIndex::new(2, 2)).unwrap() + ); + assert_eq!( + vec![VALUES4[2], node2], + *tree.get_path(NodeIndex::new(2, 3)).unwrap() + ); // check depth 1 - assert_eq!(vec![node3], tree.get_path(1, 0).unwrap()); - assert_eq!(vec![node2], tree.get_path(1, 1).unwrap()); + assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -175,7 +193,7 @@ fn small_tree_opening_is_consistent() { assert_eq!(tree.root(), Word::from(k)); - let cases: Vec<(u32, u64, Vec)> = vec![ + let cases: Vec<(u8, u64, Vec)> = vec![ (3, 0, vec![b, f, j]), (3, 1, vec![a, f, j]), (3, 4, vec![z, h, i]), @@ -189,9 +207,9 @@ fn small_tree_opening_is_consistent() { ]; for (depth, key, path) in cases { - let opening = tree.get_path(depth, key).unwrap(); + let opening = tree.get_path(NodeIndex::new(depth, key)).unwrap(); - assert_eq!(path, opening); + assert_eq!(path, *opening); } } @@ -213,7 +231,7 @@ proptest! { // traverse to root, fetching all paths for d in 1..depth { let k = key >> (depth - d); - tree.get_path(d, k).unwrap(); + tree.get_path(NodeIndex::new(d, k)).unwrap(); } }