Commit cece51c
[Leaf Counter] Adding a cryptographically secure non-empty leaf counter (#46)

- **What?** Introduce `count` to get the number of non-empty leaf nodes in every sub-trie
- **How?** Very similar to `sum` in its implementation
- **Why?** Needed to implement Relay Mining and get a reference value for the number of requests, not necessarily their cost in terms of price or compute (e.g. compute units, weight, sum, etc.); see the usage sketch after this list
- **Other**: This needs to be part of the commitment and be independent of the underlying key-value store engine
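
As a rough usage sketch of the new counter (the constructor, the `simplemap` kvstore import path, and the `Update`/`Sum`/`Count` signatures below are inferred from this diff and the repository layout, not verified against a released API):

```go
package main

import (
	"crypto/sha256"
	"fmt"

	"github.com/pokt-network/smt"
	"github.com/pokt-network/smt/kvstore/simplemap"
)

func main() {
	// Build a small in-memory sum trie.
	nodeStore := simplemap.NewSimpleMap()
	trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New())

	// Insert two non-empty leaves with weights 10 and 32.
	_ = trie.Update([]byte("relay1"), []byte("servicer1"), 10)
	_ = trie.Update([]byte("relay2"), []byte("servicer2"), 32)

	// Sum() is the total weight; Count() is the number of non-empty leaves.
	// Both are committed to in the root alongside the hashes.
	fmt.Println(trie.Sum())   // 42
	fmt.Println(trie.Count()) // 2
}
```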

---

Signed-off-by: Daniel Olshansky <[email protected]>
Co-authored-by: Redouane Lakrache <[email protected]>
Co-authored-by: h5law <[email protected]>
3 people authored Jun 13, 2024
1 parent 921f308 commit cece51c
Showing 17 changed files with 409 additions and 141 deletions.
16 changes: 9 additions & 7 deletions Makefile
@@ -1,3 +1,5 @@
SHELL := /bin/sh

.SILENT:

#####################
@@ -61,31 +63,31 @@ go_docs: check_godoc ## Generate documentation for the project

.PHONY: benchmark_all
benchmark_all: ## runs all benchmarks
go test -tags=benchmark -benchmem -run=^$ -bench Benchmark ./benchmarks -timeout 0
go test -tags=benchmark -benchmem -run=^$$ -bench Benchmark ./benchmarks -timeout 0

.PHONY: benchmark_smt
benchmark_smt: ## runs all benchmarks for the SMT
go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleTrie ./benchmarks -timeout 0
go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleTrie ./benchmarks -timeout 0

.PHONY: benchmark_smt_fill
benchmark_smt_fill: ## runs a benchmark on filling the SMT with different amounts of values
go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleTrie_Fill ./benchmarks -timeout 0 -benchtime 10x
go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleTrie_Fill ./benchmarks -timeout 0 -benchtime 10x

.PHONY: benchmark_smt_ops
benchmark_smt_ops: ## runs the benchmarks testing different operations on the SMT against different sized tries
go test -tags=benchmark -benchmem -run=^$ -bench='BenchmarkSparseMerkleTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0
go test -tags=benchmark -benchmem -run=^$$ -bench='BenchmarkSparseMerkleTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0

.PHONY: benchmark_smst
benchmark_smst: ## runs all benchmarks for the SMST
go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleSumTrie ./benchmarks -timeout 0
go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleSumTrie ./benchmarks -timeout 0

.PHONY: benchmark_smst_fill
benchmark_smst_fill: ## runs a benchmark on filling the SMST with different amounts of values
go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleSumTrie_Fill ./benchmarks -timeout 0 -benchtime 10x
go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleSumTrie_Fill ./benchmarks -timeout 0 -benchtime 10x

.PHONY: benchmark_smst_ops
benchmark_smst_ops: ## runs the benchmarks testing different operations on the SMST against different sized tries
go test -tags=benchmark -benchmem -run=^$ -bench='BenchmarkSparseMerkleSumTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0
go test -tags=benchmark -benchmem -run=^$$ -bench='BenchmarkSparseMerkleSumTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0

.PHONY: benchmark_proof_sizes
benchmark_proof_sizes: ## runs the benchmarks testing the proof sizes for different sized tries
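Side note on the `^$` → `^$$` change above: in a Makefile recipe, make expands `$` itself, so a single `$` never reaches the shell and the old `-run=^$` did not behave as intended. Escaping it as `$$` passes a literal `$`, so the shell receives `-run=^$` — an anchored-empty regex that matches no test names and leaves only the benchmarks to run. A minimal illustration, with a hypothetical `demo` target:

```make
demo: ## runs benchmarks only; `$$` reaches the shell as a literal `$`
	go test -benchmem -run=^$$ -bench=. ./... -timeout 0
```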
2 changes: 1 addition & 1 deletion benchmarks/bench_utils_test.go
@@ -38,7 +38,7 @@ var (
getSMST = func(s *smt.SMST, i uint64) error {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, i)
_, _, err := s.Get(b)
_, _, _, err := s.Get(b)
return err
}
proSMST = func(s *smt.SMST, i uint64) error {
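The extra blank identifier in `bench_utils_test.go` reflects the widened `SMST.Get` signature, which now also appears to return the leaf's count. A hedged sketch of a call site (the helper name, module path, and the meaning of `count` for a single leaf are assumptions inferred from the diff):

```go
package example

import (
	"fmt"

	"github.com/pokt-network/smt"
)

// getWithCount illustrates the assumed four-value Get: value, sum, count, err.
func getWithCount(s *smt.SMST, key []byte) error {
	value, sum, count, err := s.Get(key)
	if err != nil {
		return err
	}
	// sum is the weight stored with this leaf; count is presumably 1 for a
	// non-empty leaf and 0 for an empty one.
	fmt.Printf("value=%x sum=%d count=%d\n", value, sum, count)
	return nil
}
```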
2 changes: 2 additions & 0 deletions docs/merkle-sum-trie.md
@@ -1,5 +1,7 @@
# Sparse Merkle Sum Trie (smst)

TODO(#47): Document the new `count` addition.

<!-- toc -->

- [Sparse Merkle Sum Trie (smst)](#sparse-merkle-sum-trie-smst)
34 changes: 26 additions & 8 deletions hasher.go
@@ -124,16 +124,24 @@ func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value
// digestSumLeafNode returns the encoded leaf node data as well as its hash (i.e. digest)
func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) {
value = encodeLeafNode(path, data)
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value)

digest = th.digestData(value)
digest = append(digest, value[len(value)-sumSizeBytes:]...)
digest = append(digest, value[firstSumByteIdx:firstCountByteIdx]...)
digest = append(digest, value[firstCountByteIdx:]...)

return
}

// digestSumInnerNode returns the encoded inner node data as well as its hash (i.e. digest)
func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) {
value = encodeSumInnerNode(leftData, rightData)
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value)

digest = th.digestData(value)
digest = append(digest, value[len(value)-sumSizeBytes:]...)
digest = append(digest, value[firstSumByteIdx:firstCountByteIdx]...)
digest = append(digest, value[firstCountByteIdx:]...)

return
}

@@ -144,17 +152,27 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) {
return
}

// parseSumInnerNode returns the encoded left and right nodes as well as the sum of the current node
func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) {
// parseSumInnerNode returns the encoded left & right nodes, as well as the sum
// and non-empty leaf count in the sub-trie of the current node.
func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum, count uint64) {
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data)

// Extract the sum from the encoded node data
var sumBz [sumSizeBytes]byte
copy(sumBz[:], data[len(data)-sumSizeBytes:])
copy(sumBz[:], data[firstSumByteIdx:firstCountByteIdx])
binary.BigEndian.PutUint64(sumBz[:], sum)

// Extract the count from the encoded node data
var countBz [countSizeBytes]byte
copy(countBz[:], data[firstCountByteIdx:])
binary.BigEndian.PutUint64(countBz[:], count)

// Extract the left and right children
dataWithoutSum := data[:len(data)-sumSizeBytes]
leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBytes]
rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBytes:]
leftIdxLastByte := len(innerNodePrefix) + th.hashSize() + sumSizeBytes + countSizeBytes
dataValue := data[:firstSumByteIdx]
leftData = dataValue[len(innerNodePrefix):leftIdxLastByte]
rightData = dataValue[leftIdxLastByte:]

return
}

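Net effect of the hasher changes above: a sum-trie node's digest is the hash of the encoded node followed by its big-endian sum and count trailers, so both totals are bound into the commitment. A standalone sketch of that composition (the 8-byte widths, the helper name, and the use of SHA-256 are illustrative assumptions, not the library's internals):

```go
package example

import (
	"crypto/sha256"
	"encoding/binary"
)

const (
	sumSizeBytes   = 8 // assumed width of the big-endian sum trailer
	countSizeBytes = 8 // assumed width of the big-endian count trailer
)

// digestSumNode sketches digest = H(encodedNode) || sum || count, mirroring
// digestSumLeafNode / digestSumInnerNode above.
func digestSumNode(encodedNode []byte, sum, count uint64) []byte {
	h := sha256.Sum256(encodedNode)
	digest := append([]byte{}, h[:]...)

	var sumBz [sumSizeBytes]byte
	binary.BigEndian.PutUint64(sumBz[:], sum)
	digest = append(digest, sumBz[:]...)

	var countBz [countSizeBytes]byte
	binary.BigEndian.PutUint64(countBz[:], count)
	return append(digest, countBz[:]...)
}
```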
54 changes: 37 additions & 17 deletions node_encoders.go
@@ -80,29 +80,40 @@ func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byt

// encodeSumInnerNode encodes an inner node for an smst given the data for both children
func encodeSumInnerNode(leftData, rightData []byte) (data []byte) {
// Compute the sum of the current node
var sum [sumSizeBytes]byte
leftSum := parseSum(leftData)
rightSum := parseSum(rightData)
// TODO_CONSIDERATION: ` I chose BigEndian for readability but most computers
// now are optimized for LittleEndian encoding could be a micro optimization one day.`
binary.BigEndian.PutUint64(sum[:], leftSum+rightSum)
leftSum, leftCount := parseSumAndCount(leftData)
rightSum, rightCount := parseSumAndCount(rightData)

// Compute the SumBz of the current node
var SumBz [sumSizeBytes]byte
binary.BigEndian.PutUint64(SumBz[:], leftSum+rightSum)

// Compute the count of the current node
var countBz [countSizeBytes]byte
binary.BigEndian.PutUint64(countBz[:], leftCount+rightCount)

// Prepare and return the encoded inner node data
data = encodeInnerNode(leftData, rightData)
data = append(data, sum[:]...)
data = append(data, SumBz[:]...)
data = append(data, countBz[:]...)
return
}

// encodeSumExtensionNode encodes the data of a sum extension nodes
// encodeSumExtensionNode encodes the data of a sum extension node
func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) {
// Compute the sum of the current node
var sum [sumSizeBytes]byte
copy(sum[:], childData[len(childData)-sumSizeBytes:])
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(childData)

// Compute the sumBz of the current node
var sumBz [sumSizeBytes]byte
copy(sumBz[:], childData[firstSumByteIdx:firstCountByteIdx])

// Compute the count of the current node
var countBz [countSizeBytes]byte
copy(countBz[:], childData[firstCountByteIdx:])

// Prepare and return the encoded inner node data
data = encodeExtensionNode(pathBounds, path, childData)
data = append(data, sum[:]...)
data = append(data, sumBz[:]...)
data = append(data, countBz[:]...)
return
}

@@ -114,11 +125,20 @@ func checkPrefix(data, prefix []byte) {
}

// parseSumAndCount parses the sum and count from the encoded node data
func parseSum(data []byte) uint64 {
sum := uint64(0)
sumBz := data[len(data)-sumSizeBytes:]
func parseSumAndCount(data []byte) (sum, count uint64) {
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data)

sumBz := data[firstSumByteIdx:firstCountByteIdx]
if !bytes.Equal(sumBz, defaultEmptySum[:]) {
// TODO_CONSIDERATION: We chose BigEndian for readability, but most computers
// are now optimized for LittleEndian encoding; this could be a micro-optimization one day.
sum = binary.BigEndian.Uint64(sumBz)
}
return sum

countBz := data[firstCountByteIdx:]
if !bytes.Equal(countBz, defaultEmptyCount[:]) {
count = binary.BigEndian.Uint64(countBz)
}

return
}
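
Both `hasher.go` and `node_encoders.go` lean on `getFirstMetaByteIdx`, whose implementation is not part of this diff. A plausible sketch, reusing the assumed 8-byte constants from the earlier digest sketch and assuming the sum and count trailers sit at the very end of the encoded data (the real helper in the repository may differ):

```go
// getFirstMetaByteIdx (sketch): with data laid out as
// [...node bytes...][sum: sumSizeBytes][count: countSizeBytes],
// return the starting index of the sum and of the count.
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
	firstCountByteIdx = len(data) - countSizeBytes
	firstSumByteIdx = firstCountByteIdx - sumSizeBytes
	return
}
```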
45 changes: 33 additions & 12 deletions proofs.go
@@ -202,13 +202,15 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error {

// GetValueHash returns the value hash of the closest proof.
func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte {
if proof.ClosestValueHash == nil {
data := proof.ClosestValueHash
if data == nil {
return nil
}
if spec.sumTrie {
return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes]
firstSumByteIdx, _ := getFirstMetaByteIdx(data)
return data[:firstSumByteIdx]
}
return proof.ClosestValueHash
return data
}

func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error {
Expand Down Expand Up @@ -323,22 +325,30 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp
}

// VerifySumProof verifies a Merkle proof for a sum trie.
func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) {
func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum, count uint64, spec *TrieSpec) (bool, error) {
var sumBz [sumSizeBytes]byte
binary.BigEndian.PutUint64(sumBz[:], sum)

var countBz [countSizeBytes]byte
binary.BigEndian.PutUint64(countBz[:], count)

valueHash := spec.valueHash(value)
valueHash = append(valueHash, sumBz[:]...)
valueHash = append(valueHash, countBz[:]...)
if bytes.Equal(value, defaultEmptyValue) && sum == 0 {
valueHash = defaultEmptyValue
}

smtSpec := &TrieSpec{
th: spec.th,
ph: spec.ph,
vh: spec.vh,
sumTrie: spec.sumTrie,
}

nvh := WithValueHasher(nil)
nvh(smtSpec)

return VerifyProof(proof, root, key, valueHash, smtSpec)
}

@@ -348,26 +358,37 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie
if err := proof.validateBasic(spec); err != nil {
return false, errors.Join(ErrBadProof, err)
}

// Create a new TrieSpec with a nil path hasher.
// Since the ClosestProof already contains a hashed path, double hashing it
// will invalidate the proof.
// Since the ClosestProof already contains a hashed path, double hashing it will invalidate the proof.
nilSpec := &TrieSpec{
th: spec.th,
ph: newNilPathHasher(spec.ph.PathSize()),
vh: spec.vh,
sumTrie: spec.sumTrie,
}

// Verify the closest proof for a basic SMT
if !nilSpec.sumTrie {
return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec)
}

// TODO_DOCUMENT: Understand and explain (in comments) why this case is needed
if proof.ClosestValueHash == nil {
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec)
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, 0, nilSpec)
}
sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBytes:]

data := proof.ClosestValueHash
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data)

sumBz := data[firstSumByteIdx:firstCountByteIdx]
sum := binary.BigEndian.Uint64(sumBz)

valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes]
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec)
countBz := data[firstCountByteIdx:]
count := binary.BigEndian.Uint64(countBz)

valueHash := data[:firstSumByteIdx]
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, count, nilSpec)
}

// verifyProofWithUpdates
@@ -445,14 +466,14 @@ func VerifyCompactSumProof(
proof *SparseCompactMerkleProof,
root []byte,
key, value []byte,
sum uint64,
sum, count uint64,
spec *TrieSpec,
) (bool, error) {
decompactedProof, err := DecompactProof(proof, spec)
if err != nil {
return false, errors.Join(ErrBadProof, err)
}
return VerifySumProof(decompactedProof, root, key, value, sum, spec)
return VerifySumProof(decompactedProof, root, key, value, sum, count, spec)
}

// VerifyCompactClosestProof is similar to VerifyClosestProof but for a compacted merkle proof
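With the widened verification signatures, callers now pass the expected count alongside the sum. A hedged call-site sketch (it assumes the trie exposes `Prove`, `Root`, and a `Spec()` accessor returning the `*TrieSpec`, and that the leaf below was inserted with weight 10):

```go
package example

import "github.com/pokt-network/smt"

// verifyLeaf sketches VerifySumProof with the new count argument; a single
// non-empty leaf contributes a count of 1 to its sub-trie.
func verifyLeaf(trie *smt.SMST) (bool, error) {
	proof, err := trie.Prove([]byte("relay1"))
	if err != nil {
		return false, err
	}
	return smt.VerifySumProof(
		proof,
		trie.Root(),         // committed root, now covering sum and count
		[]byte("relay1"),    // key
		[]byte("servicer1"), // value
		10,                  // expected sum (weight) for this leaf
		1,                   // expected count: one non-empty leaf
		trie.Spec(),
	)
}
```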
7 changes: 5 additions & 2 deletions proofs_test.go
@@ -177,9 +177,12 @@ func randomizeProof(proof *SparseMerkleProof) *SparseMerkleProof {
func randomizeSumProof(proof *SparseMerkleProof) *SparseMerkleProof {
sideNodes := make([][]byte, len(proof.SideNodes))
for i := range sideNodes {
sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBytes)
data := proof.SideNodes[i]
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data)
sideNodes[i] = make([]byte, len(data)-sumSizeBytes-countSizeBytes)
rand.Read(sideNodes[i]) // nolint: errcheck
sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBytes:]...)
sideNodes[i] = append(sideNodes[i], data[firstSumByteIdx:firstCountByteIdx]...)
sideNodes[i] = append(sideNodes[i], data[firstCountByteIdx:]...)
}
return &SparseMerkleProof{
SideNodes: sideNodes,
1 change: 1 addition & 0 deletions root_test.go
@@ -60,6 +60,7 @@ func TestMerkleRoot_TrieTypes(t *testing.T) {
}
require.NotNil(t, trie.Sum())
require.EqualValues(t, 45, trie.Sum())
require.EqualValues(t, 10, trie.Count())

return
}

