diff --git a/Makefile b/Makefile index c9d4774..c6c4fc9 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +SHELL := /bin/sh + .SILENT: ##################### @@ -61,31 +63,31 @@ go_docs: check_godoc ## Generate documentation for the project .PHONY: benchmark_all benchmark_all: ## runs all benchmarks - go test -tags=benchmark -benchmem -run=^$ -bench Benchmark ./benchmarks -timeout 0 + go test -tags=benchmark -benchmem -run=^$$ -bench Benchmark ./benchmarks -timeout 0 .PHONY: benchmark_smt benchmark_smt: ## runs all benchmarks for the SMT - go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleTrie ./benchmarks -timeout 0 + go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleTrie ./benchmarks -timeout 0 .PHONY: benchmark_smt_fill benchmark_smt_fill: ## runs a benchmark on filling the SMT with different amounts of values - go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleTrie_Fill ./benchmarks -timeout 0 -benchtime 10x + go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleTrie_Fill ./benchmarks -timeout 0 -benchtime 10x .PHONY: benchmark_smt_ops benchmark_smt_ops: ## runs the benchmarks testing different operations on the SMT against different sized tries - go test -tags=benchmark -benchmem -run=^$ -bench='BenchmarkSparseMerkleTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0 + go test -tags=benchmark -benchmem -run=^$$ -bench='BenchmarkSparseMerkleTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0 .PHONY: benchmark_smst benchmark_smst: ## runs all benchmarks for the SMST - go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleSumTrie ./benchmarks -timeout 0 + go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleSumTrie ./benchmarks -timeout 0 .PHONY: benchmark_smst_fill benchmark_smst_fill: ## runs a benchmark on filling the SMST with different amounts of values - go test -tags=benchmark -benchmem -run=^$ -bench=BenchmarkSparseMerkleSumTrie_Fill ./benchmarks -timeout 0 -benchtime 10x + go test -tags=benchmark -benchmem -run=^$$ -bench=BenchmarkSparseMerkleSumTrie_Fill ./benchmarks -timeout 0 -benchtime 10x .PHONY: benchmark_smst_ops benchmark_smst_ops: ## runs the benchmarks test different operations on the SMST against different sized tries - go test -tags=benchmark -benchmem -run=^$ -bench='BenchmarkSparseMerkleSumTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0 + go test -tags=benchmark -benchmem -run=^$$ -bench='BenchmarkSparseMerkleSumTrie_(Update|Get|Prove|Delete)' ./benchmarks -timeout 0 .PHONY: benchmark_proof_sizes benchmark_proof_sizes: ## runs the benchmarks test the proof sizes for different sized tries diff --git a/benchmarks/bench_utils_test.go b/benchmarks/bench_utils_test.go index 6694086..8557fe4 100644 --- a/benchmarks/bench_utils_test.go +++ b/benchmarks/bench_utils_test.go @@ -38,7 +38,7 @@ var ( getSMST = func(s *smt.SMST, i uint64) error { b := make([]byte, 8) binary.LittleEndian.PutUint64(b, i) - _, _, err := s.Get(b) + _, _, _, err := s.Get(b) return err } proSMST = func(s *smt.SMST, i uint64) error { diff --git a/docs/merkle-sum-trie.md b/docs/merkle-sum-trie.md index ee881db..d54bf0f 100644 --- a/docs/merkle-sum-trie.md +++ b/docs/merkle-sum-trie.md @@ -1,5 +1,7 @@ # Sparse Merkle Sum Trie (smst) +TODO(#47): Document the new `count` addition. 
+ - [Sparse Merkle Sum Trie (smst)](#sparse-merkle-sum-trie-smst) diff --git a/hasher.go b/hasher.go index 3676a49..43f6ca7 100644 --- a/hasher.go +++ b/hasher.go @@ -124,16 +124,24 @@ func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value // digestSumNode returns the encoded leaf node data as well as its hash (i.e. digest) func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) { value = encodeLeafNode(path, data) + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value) + digest = th.digestData(value) - digest = append(digest, value[len(value)-sumSizeBytes:]...) + digest = append(digest, value[firstSumByteIdx:firstCountByteIdx]...) + digest = append(digest, value[firstCountByteIdx:]...) + return } // digestSumInnerNode returns the encoded inner node data as well as its hash (i.e. digest) func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) { value = encodeSumInnerNode(leftData, rightData) + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value) + digest = th.digestData(value) - digest = append(digest, value[len(value)-sumSizeBytes:]...) + digest = append(digest, value[firstSumByteIdx:firstCountByteIdx]...) + digest = append(digest, value[firstCountByteIdx:]...) + return } @@ -144,17 +152,27 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { return } -// parseSumInnerNode returns the encoded left and right nodes as well as the sum of the current node -func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) { +// parseSumInnerNode returns the encoded left & right nodes, as well as the sum +// and non-empty leaf count in the sub-trie of the current node. +func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum, count uint64) { + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data) + // Extract the sum from the encoded node data var sumBz [sumSizeBytes]byte - copy(sumBz[:], data[len(data)-sumSizeBytes:]) + copy(sumBz[:], data[firstSumByteIdx:firstCountByteIdx]) binary.BigEndian.PutUint64(sumBz[:], sum) + // Extract the count from the encoded node data + var countBz [countSizeBytes]byte + copy(countBz[:], data[firstCountByteIdx:]) + binary.BigEndian.PutUint64(countBz[:], count) + // Extract the left and right children - dataWithoutSum := data[:len(data)-sumSizeBytes] - leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBytes] - rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBytes:] + leftIdxLastByte := len(innerNodePrefix) + th.hashSize() + sumSizeBytes + countSizeBytes + dataValue := data[:firstSumByteIdx] + leftData = dataValue[len(innerNodePrefix):leftIdxLastByte] + rightData = dataValue[leftIdxLastByte:] + return } diff --git a/node_encoders.go b/node_encoders.go index fdc6a06..309c300 100644 --- a/node_encoders.go +++ b/node_encoders.go @@ -80,29 +80,40 @@ func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byt // encodeSumInnerNode encodes an inner node for an smst given the data for both children func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { - // Compute the sum of the current node - var sum [sumSizeBytes]byte - leftSum := parseSum(leftData) - rightSum := parseSum(rightData) - // TODO_CONSIDERATION: ` I chose BigEndian for readability but most computers - // now are optimized for LittleEndian encoding could be a micro optimization one day.` - 
binary.BigEndian.PutUint64(sum[:], leftSum+rightSum)
+	leftSum, leftCount := parseSumAndCount(leftData)
+	rightSum, rightCount := parseSumAndCount(rightData)
+
+	// Compute the sum of the current node
+	var sumBz [sumSizeBytes]byte
+	binary.BigEndian.PutUint64(sumBz[:], leftSum+rightSum)
+
+	// Compute the count of the current node
+	var countBz [countSizeBytes]byte
+	binary.BigEndian.PutUint64(countBz[:], leftCount+rightCount)

 	// Prepare and return the encoded inner node data
 	data = encodeInnerNode(leftData, rightData)
-	data = append(data, sum[:]...)
+	data = append(data, sumBz[:]...)
+	data = append(data, countBz[:]...)
 	return
 }

-// encodeSumExtensionNode encodes the data of a sum extension nodes
+// encodeSumExtensionNode encodes the data of a sum extension node
 func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) {
-	// Compute the sum of the current node
-	var sum [sumSizeBytes]byte
-	copy(sum[:], childData[len(childData)-sumSizeBytes:])
+	firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(childData)
+
+	// Compute the sum of the current node
+	var sumBz [sumSizeBytes]byte
+	copy(sumBz[:], childData[firstSumByteIdx:firstCountByteIdx])
+
+	// Compute the count of the current node
+	var countBz [countSizeBytes]byte
+	copy(countBz[:], childData[firstCountByteIdx:])

 	// Prepare and return the encoded inner node data
 	data = encodeExtensionNode(pathBounds, path, childData)
-	data = append(data, sum[:]...)
+	data = append(data, sumBz[:]...)
+	data = append(data, countBz[:]...)
 	return
 }

@@ -114,11 +125,20 @@ func checkPrefix(data, prefix []byte) {
 }

-// parseSum parses the sum from the encoded node data
-func parseSum(data []byte) uint64 {
-	sum := uint64(0)
-	sumBz := data[len(data)-sumSizeBytes:]
+// parseSumAndCount parses the sum and count from the encoded node data
+func parseSumAndCount(data []byte) (sum, count uint64) {
+	firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data)
+
+	sumBz := data[firstSumByteIdx:firstCountByteIdx]
 	if !bytes.Equal(sumBz, defaultEmptySum[:]) {
+		// TODO_CONSIDERATION: We chose BigEndian for readability, but most computers
+		// are now optimized for LittleEndian, so switching could be a micro-optimization one day.
 		sum = binary.BigEndian.Uint64(sumBz)
 	}
-	return sum
+
+	countBz := data[firstCountByteIdx:]
+	if !bytes.Equal(countBz, defaultEmptyCount[:]) {
+		count = binary.BigEndian.Uint64(countBz)
+	}
+
+	return
 }
diff --git a/proofs.go b/proofs.go
index b834468..7b7a762 100644
--- a/proofs.go
+++ b/proofs.go
@@ -202,13 +202,15 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error {

 // GetValueHash returns the value hash of the closest proof.
 func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte {
-	if proof.ClosestValueHash == nil {
+	data := proof.ClosestValueHash
+	if data == nil {
 		return nil
 	}
 	if spec.sumTrie {
-		return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes]
+		firstSumByteIdx, _ := getFirstMetaByteIdx(data)
+		return data[:firstSumByteIdx]
 	}
-	return proof.ClosestValueHash
+	return data
 }

 func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error {
@@ -323,22 +325,30 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp
 }

 // VerifySumProof verifies a Merkle proof for a sum trie.
-func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { +func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum, count uint64, spec *TrieSpec) (bool, error) { var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) + + var countBz [countSizeBytes]byte + binary.BigEndian.PutUint64(countBz[:], count) + valueHash := spec.valueHash(value) valueHash = append(valueHash, sumBz[:]...) + valueHash = append(valueHash, countBz[:]...) if bytes.Equal(value, defaultEmptyValue) && sum == 0 { valueHash = defaultEmptyValue } + smtSpec := &TrieSpec{ th: spec.th, ph: spec.ph, vh: spec.vh, sumTrie: spec.sumTrie, } + nvh := WithValueHasher(nil) nvh(smtSpec) + return VerifyProof(proof, root, key, valueHash, smtSpec) } @@ -348,26 +358,37 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie if err := proof.validateBasic(spec); err != nil { return false, errors.Join(ErrBadProof, err) } + // Create a new TrieSpec with a nil path hasher. - // Since the ClosestProof already contains a hashed path, double hashing it - // will invalidate the proof. + // Since the ClosestProof already contains a hashed path, double hashing it will invalidate the proof. nilSpec := &TrieSpec{ th: spec.th, ph: newNilPathHasher(spec.ph.PathSize()), vh: spec.vh, sumTrie: spec.sumTrie, } + + // Verify the closest proof for a basic SMT if !nilSpec.sumTrie { return VerifyProof(proof.ClosestProof, root, proof.ClosestPath, proof.ClosestValueHash, nilSpec) } + + // TODO_DOCUMENT: Understand and explain (in comments) why this case is needed if proof.ClosestValueHash == nil { - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, 0, nilSpec) } - sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBytes:] + + data := proof.ClosestValueHash + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data) + + sumBz := data[firstSumByteIdx:firstCountByteIdx] sum := binary.BigEndian.Uint64(sumBz) - valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes] - return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) + countBz := data[firstCountByteIdx:] + count := binary.BigEndian.Uint64(countBz) + + valueHash := data[:firstSumByteIdx] + return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, count, nilSpec) } // verifyProofWithUpdates @@ -445,14 +466,14 @@ func VerifyCompactSumProof( proof *SparseCompactMerkleProof, root []byte, key, value []byte, - sum uint64, + sum, count uint64, spec *TrieSpec, ) (bool, error) { decompactedProof, err := DecompactProof(proof, spec) if err != nil { return false, errors.Join(ErrBadProof, err) } - return VerifySumProof(decompactedProof, root, key, value, sum, spec) + return VerifySumProof(decompactedProof, root, key, value, sum, count, spec) } // VerifyCompactClosestProof is similar to VerifyClosestProof but for a compacted merkle proof diff --git a/proofs_test.go b/proofs_test.go index b673b2f..c91d341 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -177,9 +177,12 @@ func randomizeProof(proof *SparseMerkleProof) *SparseMerkleProof { func randomizeSumProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { - sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBytes) + data := proof.SideNodes[i] + firstSumByteIdx, 
firstCountByteIdx := getFirstMetaByteIdx(data)
+		sideNodes[i] = make([]byte, len(data)-sumSizeBytes-countSizeBytes)
 		rand.Read(sideNodes[i]) // nolint: errcheck
-		sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBytes:]...)
+		sideNodes[i] = append(sideNodes[i], data[firstSumByteIdx:firstCountByteIdx]...)
+		sideNodes[i] = append(sideNodes[i], data[firstCountByteIdx:]...)
 	}
 	return &SparseMerkleProof{
 		SideNodes:             sideNodes,
diff --git a/root_test.go b/root_test.go
index 08f8b6d..fb89236 100644
--- a/root_test.go
+++ b/root_test.go
@@ -60,6 +60,7 @@ func TestMerkleRoot_TrieTypes(t *testing.T) {
 			}
 			require.NotNil(t, trie.Sum())
 			require.EqualValues(t, 45, trie.Sum())
+			require.EqualValues(t, 10, trie.Count())
 			return
 		}

diff --git a/smst.go b/smst.go
index a2bc49c..89b0c13 100644
--- a/smst.go
+++ b/smst.go
@@ -9,8 +9,19 @@ import (
 )

 const (
-	// The number of bits used to represent the sum of a node
+	// The number of bytes used to represent the sum of a node
 	sumSizeBytes = 8
+
+	// The number of bytes used to track the count of non-empty nodes in the trie.
+	//
+	// TODO_TECHDEBT: Since we are using sha256, we could theoretically have
+	// 2^256 leaves. This would require 32 bytes, and would not fit in a uint64.
+	// For now, we are assuming that we will not have more than 2^64 - 1 leaves.
+	//
+	// The need for this variable could be removed, but it is kept around to enable
+	// a simpler transition to little endian encoding if/when necessary.
+	// Ref: https://github.com/pokt-network/smt/pull/46#discussion_r1636975124
+	countSizeBytes = 8
 )

 var _ SparseMerkleSumTrie = (*SMST)(nil)
@@ -71,27 +82,38 @@ func (smst *SMST) Spec() *TrieSpec {
 	return &smst.TrieSpec
 }

-// Get retrieves the value digest for the given key and the digest of the value
-// along with its weight provided a leaf node exists.
+// Get retrieves the value digest for the given key along with its weight, assuming
+// the node exists; otherwise, the default placeholder values are returned
 func (smst *SMST) Get(key []byte) (valueDigest []byte, weight uint64, err error) {
 	// Retrieve the value digest from the trie for the given key
-	valueDigest, err = smst.SMT.Get(key)
+	value, err := smst.SMT.Get(key)
 	if err != nil {
 		return nil, 0, err
 	}

-	// Check if it ias an empty branch
-	if bytes.Equal(valueDigest, defaultEmptyValue) {
+	// Check if it is an empty branch
+	if bytes.Equal(value, defaultEmptyValue) {
 		return defaultEmptyValue, 0, nil
 	}

+	firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value)
+
+	// Extract the value digest only
+	valueDigest = value[:firstSumByteIdx]
+
 	// Retrieve the node weight
 	var weightBz [sumSizeBytes]byte
-	copy(weightBz[:], valueDigest[len(valueDigest)-sumSizeBytes:])
+	copy(weightBz[:], value[firstSumByteIdx:firstCountByteIdx])
 	weight = binary.BigEndian.Uint64(weightBz[:])

-	// Remove the weight from the value digest
-	valueDigest = valueDigest[:len(valueDigest)-sumSizeBytes]
+	// Retrieve the number of non-empty nodes in the sub-trie
+	var countBz [countSizeBytes]byte
+	copy(countBz[:], value[firstCountByteIdx:])
+	count := binary.BigEndian.Uint64(countBz[:])
+
+	if count != 1 {
+		panic("count for leaf node should always be 1")
+	}

 	// Return the value digest and weight
 	return valueDigest, weight, nil
@@ -109,9 +131,14 @@ func (smst *SMST) Update(key, value []byte, weight uint64) error {
 	var weightBz [sumSizeBytes]byte
 	binary.BigEndian.PutUint64(weightBz[:], weight)

+	// Convert the node count (1 for a single leaf) to a byte slice
+	var countBz [countSizeBytes]byte
+	binary.BigEndian.PutUint64(countBz[:], 1)
+
 	// Compute the digest of the value and append the weight to it
 	valueDigest := smst.valueHash(value)
 	valueDigest = append(valueDigest, weightBz[:]...)
+	valueDigest = append(valueDigest, countBz[:]...)

 	// Return the result of the trie update
 	return smst.SMT.Update(key, valueDigest)
@@ -150,11 +177,39 @@ func (smst *SMST) Root() MerkleRoot {
 // Sum returns the sum of the entire trie stored in the root.
 // If the tree is not a sum tree, it will panic.
 func (smst *SMST) Sum() uint64 {
-	rootDigest := smst.Root()
+	rootDigest := []byte(smst.Root())
+
 	if !smst.Spec().sumTrie {
 		panic("SMST: not a merkle sum trie")
 	}
+
+	firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
+
 	var sumBz [sumSizeBytes]byte
-	copy(sumBz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBytes:])
+	copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
 	return binary.BigEndian.Uint64(sumBz[:])
 }
+
+// Count returns the number of non-empty leaves in the entire trie stored in the root.
+func (smst *SMST) Count() uint64 {
+	rootDigest := []byte(smst.Root())
+
+	if !smst.Spec().sumTrie {
+		panic("SMST: not a merkle sum trie")
+	}
+
+	_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
+
+	var countBz [countSizeBytes]byte
+	copy(countBz[:], rootDigest[firstCountByteIdx:])
+	return binary.BigEndian.Uint64(countBz[:])
+}
+
+// getFirstMetaByteIdx returns the index of the first sum byte and the first count byte
+// in the data slice provided. This is useful metadata when parsing the data
+// of any node in the trie.
+func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
+	firstCountByteIdx = len(data) - countSizeBytes
+	firstSumByteIdx = firstCountByteIdx - sumSizeBytes
+	return
+}
diff --git a/smst_example_test.go b/smst_example_test.go
index f1cdbca..c5d5ac6 100644
--- a/smst_example_test.go
+++ b/smst_example_test.go
@@ -2,13 +2,16 @@ package smt_test

 import (
 	"crypto/sha256"
-	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/require"

 	"github.com/pokt-network/smt"
 	"github.com/pokt-network/smt/kvstore/simplemap"
 )

-func ExampleSMST() {
+// TestExampleSMST is a test that acts as an example of how to use the SMST.
+func TestExampleSMST(t *testing.T) {
 	// Initialise a new in-memory key-value store to store the nodes of the trie
 	// (Note: the trie only stores hashed values, not raw value data)
 	nodeStore := simplemap.NewSimpleMap()
@@ -36,13 +39,21 @@ func ExampleSMST() {
 	root := trie.Root()

 	// Verify the Merkle proof for "foo"="oof" where "foo" has a sum of 10
-	valid_true1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec())
+	valid_true1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, 1, trie.Spec())
+	require.True(t, valid_true1)
 	// Verify the Merkle proof for "baz"="zab" where "baz" has a sum of 7
-	valid_true2, _ := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec())
+	valid_true2, _ := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, 1, trie.Spec())
+	require.True(t, valid_true2)
 	// Verify the Merkle proof for "bin"="nib" where "bin" has a sum of 3
-	valid_true3, _ := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec())
+	valid_true3, _ := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, 1, trie.Spec())
+	require.True(t, valid_true3)
 	// Fail to verify the Merkle proof for "foo"="oof" where "foo" has a sum of 11
-	valid_false1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec())
-	fmt.Println(valid_true1, valid_true2, valid_true3, valid_false1)
-	// Output: true true true false
+	valid_false1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, 1, trie.Spec())
+	require.False(t, valid_false1)
+
+	// Verify the total sum of the trie
+	require.EqualValues(t, 20, trie.Sum())
+
+	// Verify the number of non-empty leaves in the trie
+	require.EqualValues(t, 3, trie.Count())
 }
diff --git a/smst_proofs_test.go b/smst_proofs_test.go
index d0d8c9d..f33bd4e 100644
--- a/smst_proofs_test.go
+++ b/smst_proofs_test.go
@@ -33,58 +33,78 @@ func TestSMST_Proof_Operations(t *testing.T) {
 	proof, err = smst.Prove([]byte("testKey3"))
 	require.NoError(t, err)
 	checkCompactEquivalence(t, proof, base)
-	result, err = VerifySumProof(proof, base.placeholder(), []byte("testKey3"), defaultEmptyValue, 0, base)
+	result, err = VerifySumProof(proof, base.placeholder(), []byte("testKey3"), defaultEmptyValue, 0, 0, base)
 	require.NoError(t, err)
 	require.True(t, result)
-	result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base)
+	result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, 1, base)
 	require.NoError(t, err)
 	require.False(t, result)

 	// Add a key, generate and verify a Merkle proof.
err = smst.Update([]byte("testKey"), []byte("testValue"), 5) require.NoError(t, err) + root = smst.Root() proof, err = smst.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, 1, base) // valid require.NoError(t, err) require.True(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, 2, base) // wrong count require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, 1, base) // wrong value require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, 1, base) // wrong sum + require.NoError(t, err) + require.False(t, result) + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, 1, base) // wrong value and sum require.NoError(t, err) require.False(t, result) // Add a key, generate and verify both Merkle proofs. err = smst.Update([]byte("testKey2"), []byte("testValue"), 5) require.NoError(t, err) + root = smst.Root() proof, err = smst.Prove([]byte("testKey")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, base) // valid + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, 1, base) // valid require.NoError(t, err) require.True(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, base) // wrong value + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 5, 2, base) // wrong count require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, base) // wrong sum + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 5, 1, base) // wrong value require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, base) // wrong value and sum + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("testValue"), 10, 1, base) // wrong sum + require.NoError(t, err) + require.False(t, result) + + result, err = VerifySumProof(proof, root, []byte("testKey"), []byte("badValue"), 10, 1, base) // wrong value and sum require.NoError(t, err) require.False(t, result) + result, err = VerifySumProof( randomizeSumProof(proof), root, []byte("testKey"), []byte("testValue"), 5, + 1, base, ) // invalid proof require.NoError(t, err) @@ -93,16 +113,16 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey2")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 5, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 5, 1, base) // valid require.NoError(t, err) 
require.True(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 5, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 5, 1, base) // wrong value require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 10, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("testValue"), 10, 1, base) // wrong sum require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, base) // wrong value and sum + result, err = VerifySumProof(proof, root, []byte("testKey2"), []byte("badValue"), 10, 1, base) // wrong value and sum require.NoError(t, err) require.False(t, result) result, err = VerifySumProof( @@ -111,22 +131,23 @@ func TestSMST_Proof_Operations(t *testing.T) { []byte("testKey2"), []byte("testValue"), 5, + 1, base, ) // invalid proof require.NoError(t, err) require.False(t, result) - // Try proving a default value for a non-default leaf. + // Try (and fail) proving a default value for a non-default leaf. var sum [sumSizeBytes]byte binary.BigEndian.PutUint64(sum[:], 5) - tval := base.valueHash([]byte("testValue")) - tval = append(tval, sum[:]...) - _, leafData := base.th.digestSumLeafNode(base.ph.Path([]byte("testKey2")), tval) + testVal := base.valueHash([]byte("testValue")) + testVal = append(testVal, sum[:]...) + _, leafData := base.th.digestSumLeafNode(base.ph.Path([]byte("testKey2")), testVal) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultEmptyValue, 0, base) + result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultEmptyValue, 0, 0, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) @@ -134,13 +155,13 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 0, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 0, 0, base) // valid require.NoError(t, err) require.True(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, base) // wrong value + result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, 0, base) // wrong value require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 5, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 5, 0, base) // wrong sum require.NoError(t, err) require.False(t, result) result, err = VerifySumProof( @@ -149,6 +170,7 @@ func TestSMST_Proof_Operations(t *testing.T) { []byte("testKey3"), defaultEmptyValue, 0, + 0, base, ) // invalid proof require.NoError(t, err) @@ -180,7 +202,7 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { } proof.SideNodes = sideNodes require.EqualError(t, proof.validateBasic(base), "too many side nodes: got 257 but max is 256") - result, err := VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + result, err := VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, 1, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = 
CompactProof(proof, base) @@ -190,17 +212,17 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { proof, _ = smst.Prove([]byte("testKey1")) proof.NonMembershipLeafData = make([]byte, 1) require.EqualError(t, proof.validateBasic(base), "invalid non-membership leaf data size: got 1 but min is 33") - result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, 1, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) require.Error(t, err) - // Case: unexpected sidenode size. + // Case: unexpected side node size. proof, _ = smst.Prove([]byte("testKey1")) proof.SideNodes[0] = make([]byte, 1) - require.EqualError(t, proof.validateBasic(base), "invalid side node size: got 1 but want 40") - result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + require.EqualError(t, proof.validateBasic(base), "invalid side node size: got 1 but want 48") + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, 1, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) @@ -212,10 +234,10 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { require.EqualError( t, proof.validateBasic(base), - "invalid sibling data hash: got 437437455c0f5ca33597b9dd2a307bdfcc6833d3c272e101f30ed6358783fc247f0b9966865746c1 but want 1dc9a3da748c53b22c9e54dcafe9e872341babda9b3e50577f0b9966865746c10000000000000009", + "invalid sibling data hash: got 30ecfc36781633f6765088e69165733d7192483b4468ca53ef21794bb035f72f799a6026c448b3eacafa2b82ca1ff7f2 but want 3c54b08cc0074a44a12cb0ea0486d29d799a6026c448b3eacafa2b82ca1ff7f200000000000000090000000000000003", ) - result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, base) + result, err = VerifySumProof(proof, root, []byte("testKey1"), []byte("testValue1"), 1, 1, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) _, err = CompactProof(proof, base) @@ -302,6 +324,7 @@ func TestSMST_ProveClosest(t *testing.T) { var root []byte var err error var sumBz [sumSizeBytes]byte + var countBz [countSizeBytes]byte smn = simplemap.NewSimpleMap() require.NoError(t, err) @@ -337,6 +360,8 @@ func TestSMST_ProveClosest(t *testing.T) { closestValueHash := []byte("testValue2") binary.BigEndian.PutUint64(sumBz[:], 24) closestValueHash = append(closestValueHash, sumBz[:]...) + binary.BigEndian.PutUint64(countBz[:], 1) + closestValueHash = append(closestValueHash, countBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path[:], FlippedBits: []int{3, 6}, @@ -363,6 +388,7 @@ func TestSMST_ProveClosest(t *testing.T) { closestValueHash = []byte("testValue4") binary.BigEndian.PutUint64(sumBz[:], 30) closestValueHash = append(closestValueHash, sumBz[:]...) + closestValueHash = append(closestValueHash, countBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path2[:], FlippedBits: []int{3}, @@ -427,9 +453,17 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") + + // Manually insert the sum, which is the weight of the single node in the trie var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], 5) closestValueHash = append(closestValueHash, sumBz[:]...) 
+ + // Manually insert the count, which is 1 for a single leaf + var countBz [countSizeBytes]byte + binary.BigEndian.PutUint64(countBz[:], 1) + closestValueHash = append(closestValueHash, countBz[:]...) + require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path[:], FlippedBits: []int{}, diff --git a/smst_test.go b/smst_test.go index e331994..bd52cd4 100644 --- a/smst_test.go +++ b/smst_test.go @@ -359,6 +359,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Update([]byte("testKey"), []byte("testValue"), 5) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) // only root node + require.Equal(t, uint64(1), impl.Count()) } t.Run("delete 1", func(t *testing.T) { @@ -366,6 +367,7 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Delete([]byte("testKey")) require.NoError(t, err) require.Equal(t, 0, nodeCount(t)) + require.Equal(t, uint64(0), impl.Count()) }) t.Run("overwrite 1", func(t *testing.T) { @@ -373,46 +375,76 @@ func TestSMST_OrphanRemoval(t *testing.T) { err = smst.Update([]byte("testKey"), []byte("testValue2"), 10) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) + require.Equal(t, uint64(1), impl.Count()) }) - type testCase struct { - keys []string - count int - } - // sha256(testKey) = 0001... - // sha256(testKey2) = 1000... common prefix len 0; 3 nodes (root + 2 leaf) - // sha256(foo) = 0010... common prefix len 2; 5 nodes (3 inner + 2 leaf) - cases := []testCase{ - {[]string{"testKey2"}, 3}, - {[]string{"foo"}, 4}, - {[]string{"testKey2", "foo"}, 6}, - {[]string{"a", "b", "c", "d", "e"}, 14}, - } - t.Run("overwrite and delete", func(t *testing.T) { setup() err = smst.Update([]byte("testKey"), []byte("testValue2"), 2) require.NoError(t, err) require.Equal(t, 1, nodeCount(t)) + require.Equal(t, uint64(1), impl.Count()) err = smst.Delete([]byte("testKey")) require.NoError(t, err) require.Equal(t, 0, nodeCount(t)) + require.Equal(t, uint64(0), impl.Count()) + }) - for tci, tc := range cases { + type testCase struct { + desc string + keys []string + expectedNodeCount int + expectedLeafCount int + } + // sha256(testKey) = 0001... + // sha256(testKey2) = 1000... common prefix len 0; 3 nodes (root + 2 leaf) + // sha256(foo) = 0010... 
common prefix len 2; 5 nodes (3 inner + 2 leaf) + cases := []testCase{ + { + desc: "insert a single key (testKey2) which DOES NOT HAVE a similar prefix to the already present key (testKey)", + keys: []string{"testKey2"}, + expectedNodeCount: 3, + expectedLeafCount: 2, + }, + { + desc: "insert a single key (foo) which DOES HAVE a similar prefix to the already present key (testKey)", + keys: []string{"foo"}, + expectedNodeCount: 4, + expectedLeafCount: 2, + }, + { + desc: "override two existing keys which were added by the test cases above", + keys: []string{"testKey2", "foo"}, + expectedNodeCount: 6, + expectedLeafCount: 3, + }, + { + desc: "add 5 new keys with no common prefix to the existing keys", + keys: []string{"a", "b", "c", "d", "e"}, + expectedNodeCount: 14, + expectedLeafCount: 6, + }, + } + for tci, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + // Inserts a key-value pair for `testKey` setup() + // Insert value for every key for _, key := range tc.keys { err = smst.Update([]byte(key), []byte("testValue2"), 10) require.NoError(t, err, tci) } - require.Equal(t, tc.count, nodeCount(t), tci) + require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci) + require.Equal(t, uint64(tc.expectedLeafCount), impl.Count()) - // Overwrite doesn't change count + // Overwrite doesn't change node or leaf count for _, key := range tc.keys { err = smst.Update([]byte(key), []byte("testValue3"), 10) require.NoError(t, err, tci) } - require.Equal(t, tc.count, nodeCount(t), tci) + require.Equal(t, tc.expectedNodeCount, nodeCount(t), tci) + require.Equal(t, uint64(tc.expectedLeafCount), impl.Count()) // Deletion removes all nodes except root for _, key := range tc.keys { @@ -420,13 +452,15 @@ func TestSMST_OrphanRemoval(t *testing.T) { require.NoError(t, err, tci) } require.Equal(t, 1, nodeCount(t), tci) + require.Equal(t, uint64(1), impl.Count()) // Deleting and re-inserting a persisted node doesn't change count require.NoError(t, smst.Delete([]byte("testKey"))) require.NoError(t, smst.Update([]byte("testKey"), []byte("testValue"), 10)) require.Equal(t, 1, nodeCount(t), tci) - } - }) + require.Equal(t, uint64(1), impl.Count()) + }) + } } func TestSMST_TotalSum(t *testing.T) { @@ -441,20 +475,31 @@ func TestSMST_TotalSum(t *testing.T) { // Check root hash contains the correct hex sum root1 := smst.Root() - sumBz := root1[len(root1)-sumSizeBytes:] + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root1) + + // Get the sum from the root hash + sumBz := root1[firstSumByteIdx:firstCountByteIdx] rootSum := binary.BigEndian.Uint64(sumBz) - require.NoError(t, err) - // Calculate total sum of the trie + // Get the count from the root hash + countBz := root1[firstCountByteIdx:] + rootCount := binary.BigEndian.Uint64(countBz) + + // Retrieve and compare the sum sum := smst.Sum() require.Equal(t, sum, uint64(15)) require.Equal(t, sum, rootSum) + // Retrieve and compare the count + count := smst.Count() + require.Equal(t, count, uint64(3)) + require.Equal(t, count, rootCount) + // Prove inclusion proof, err := smst.Prove([]byte("key1")) require.NoError(t, err) checkCompactEquivalence(t, proof, smst.Spec()) - valid, err := VerifySumProof(proof, root1, []byte("key1"), []byte("value1"), 5, smst.Spec()) + valid, err := VerifySumProof(proof, root1, []byte("key1"), []byte("value1"), 5, 1, smst.Spec()) require.NoError(t, err) require.True(t, valid) @@ -464,6 +509,10 @@ func TestSMST_TotalSum(t *testing.T) { sum = smst.Sum() require.Equal(t, sum, uint64(10)) + // Check that the count is correct after 
deleting a key + count = smst.Count() + require.Equal(t, count, uint64(2)) + // Check that the sum is correct after importing the trie require.NoError(t, smst.Commit()) root2 := smst.Root() @@ -471,6 +520,10 @@ func TestSMST_TotalSum(t *testing.T) { sum = smst.Sum() require.Equal(t, sum, uint64(10)) + // Check that the count is correct after importing the trie + count = smst.Count() + require.Equal(t, count, uint64(2)) + // Calculate the total sum of a larger trie snm = simplemap.NewSimpleMap() smst = NewSparseMerkleSumTrie(snm, sha256.New()) @@ -481,6 +534,10 @@ func TestSMST_TotalSum(t *testing.T) { require.NoError(t, smst.Commit()) sum = smst.Sum() require.Equal(t, sum, uint64(49995000)) + + // Check that the count is correct after building a larger trie + count = smst.Count() + require.Equal(t, count, uint64(9999)) } func TestSMST_Retrieval(t *testing.T) { @@ -549,4 +606,7 @@ func TestSMST_Retrieval(t *testing.T) { sum = lazy.Sum() require.Equal(t, sum, uint64(15)) + + count := lazy.Count() + require.Equal(t, count, uint64(3)) } diff --git a/smst_utils_test.go b/smst_utils_test.go index db2acd0..03920d9 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -18,18 +18,24 @@ type SMSTWithStorage struct { preimages kvstore.MapStore } -// Update updates a key with a new value in the trie and adds the value to the -// preimages KVStore -// Preimages are the values prior to them being hashed - they are used to -// confirm the values are in the trie +// Update a key with a new value in the trie and add it to the preimages KVStore. +// Preimages are the values prior to being hashed, used to confirm the values are in the trie. func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { if err := smst.SMST.Update(key, value, sum); err != nil { return err } valueHash := smst.valueHash(value) + + // Append the sum to the value before storing it var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) + + // Append the count to the value before storing it + var countBz [countSizeBytes]byte + binary.BigEndian.PutUint64(countBz[:], 1) + value = append(value, countBz[:]...) + return smst.preimages.Set(valueHash, value) } @@ -48,6 +54,7 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { if valueHash == nil { return nil, 0, nil } + // Extract the value from the preimages KVStore value, err := smst.preimages.Get(valueHash) if err != nil { if errors.Is(err, ErrKeyNotFound) { @@ -57,13 +64,18 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { // Otherwise percolate up any other error return nil, 0, err } + + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(value) + + // Extract the sum from the value var sumBz [sumSizeBytes]byte - copy(sumBz[:], value[len(value)-sumSizeBytes:]) + copy(sumBz[:], value[firstSumByteIdx:firstCountByteIdx]) storedSum := binary.BigEndian.Uint64(sumBz[:]) if storedSum != sum { return nil, 0, fmt.Errorf("sum mismatch for %s: got %d, expected %d", string(key), storedSum, sum) } - return value[:len(value)-sumSizeBytes], storedSum, nil + + return value[:firstSumByteIdx], storedSum, nil } // Has returns true if the value at the given key is non-default, false otherwise. 
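With this change, every encoded sum-trie node, and the root returned by Root(), carries a fixed-size metadata suffix: [data][8-byte BigEndian sum][8-byte BigEndian count]. The sketch below decodes that suffix outside the library; it assumes only the layout above, and decodeSumTrieMeta is an illustrative helper, not part of the smt API.

package main

import (
	"encoding/binary"
	"fmt"
)

const (
	sumSizeBytes   = 8 // mirrors the constant in smst.go above
	countSizeBytes = 8 // mirrors the constant in smst.go above
)

// decodeSumTrieMeta splits an encoded sum-trie root (or node) into its data,
// sum, and count, mirroring getFirstMetaByteIdx from the diff above.
func decodeSumTrieMeta(encoded []byte) (data []byte, sum, count uint64) {
	firstCountByteIdx := len(encoded) - countSizeBytes
	firstSumByteIdx := firstCountByteIdx - sumSizeBytes

	data = encoded[:firstSumByteIdx]
	sum = binary.BigEndian.Uint64(encoded[firstSumByteIdx:firstCountByteIdx])
	count = binary.BigEndian.Uint64(encoded[firstCountByteIdx:])
	return data, sum, count
}

func main() {
	// Fabricated 32-byte digest followed by sum=45 and count=10, matching the
	// expectations asserted in root_test.go above.
	root := make([]byte, 32)
	root = binary.BigEndian.AppendUint64(root, 45)
	root = binary.BigEndian.AppendUint64(root, 10)

	digest, sum, count := decodeSumTrieMeta(root)
	fmt.Println(len(digest), sum, count) // 32 45 10
}
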
diff --git a/smt.go b/smt.go index 120186f..73b9c37 100644 --- a/smt.go +++ b/smt.go @@ -632,7 +632,7 @@ func (smt *SMT) parseSumTrieNode(data, digest []byte) (trieNode, error) { digest: digest, }, nil } else if isExtNode(data) { - pathBounds, path, childData, _ := smt.parseSumExtNode(data) + pathBounds, path, childData, _, _ := smt.parseSumExtNode(data) return &extensionNode{ path: path, pathBounds: [2]byte(pathBounds), @@ -641,7 +641,7 @@ func (smt *SMT) parseSumTrieNode(data, digest []byte) (trieNode, error) { digest: digest, }, nil } else if isInnerNode(data) { - leftData, rightData, _ := smt.th.parseSumInnerNode(data) + leftData, rightData, _, _ := smt.th.parseSumInnerNode(data) return &innerNode{ leftChild: &lazyNode{leftData}, rightChild: &lazyNode{rightData}, diff --git a/smt_example_test.go b/smt_example_test.go index 2f7af1b..5f831d0 100644 --- a/smt_example_test.go +++ b/smt_example_test.go @@ -4,6 +4,8 @@ import ( "crypto/sha256" "testing" + "github.com/stretchr/testify/require" + "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) @@ -29,9 +31,8 @@ func TestExampleSMT(t *testing.T) { // Verify the Merkle proof for "foo"="bar" valid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("bar"), trie.Spec()) + require.True(t, valid) // Attempt to verify the Merkle proof for "foo"="baz" invalid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("baz"), trie.Spec()) - - // Output: true false - t.Log(valid, invalid) + require.False(t, invalid) } diff --git a/trie_spec.go b/trie_spec.go index a9f047e..304e8fb 100644 --- a/trie_spec.go +++ b/trie_spec.go @@ -33,6 +33,7 @@ func (spec *TrieSpec) placeholder() []byte { if spec.sumTrie { placeholder := spec.th.placeholder() placeholder = append(placeholder, defaultEmptySum[:]...) + placeholder = append(placeholder, defaultEmptyCount[:]...) return placeholder } return spec.th.placeholder() @@ -41,7 +42,7 @@ func (spec *TrieSpec) placeholder() []byte { // hashSize returns the hash size depending on the trie type func (spec *TrieSpec) hashSize() int { if spec.sumTrie { - return spec.th.hashSize() + sumSizeBytes + return spec.th.hashSize() + sumSizeBytes + countSizeBytes } return spec.th.hashSize() } @@ -100,13 +101,17 @@ func (spec *TrieSpec) hashSerialization(data []byte) []byte { // Used for verification of serialized proof data for sum trie nodes func (spec *TrieSpec) hashSumSerialization(data []byte) []byte { if isExtNode(data) { - pathBounds, path, childHash, _ := spec.parseSumExtNode(data) + pathBounds, path, childHash, _, _ := spec.parseSumExtNode(data) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) return spec.digestSumNode(&ext) } + + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data) + digest := spec.th.digestData(data) - digest = append(digest, data[len(data)-sumSizeBytes:]...) + digest = append(digest, data[firstSumByteIdx:firstCountByteIdx]...) + digest = append(digest, data[firstCountByteIdx:]...) 
return digest } @@ -188,7 +193,7 @@ func (spec *TrieSpec) encodeSumNode(node trieNode) (preImage []byte) { return nil } -// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum] +// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum]+[8 byte count] func (spec *TrieSpec) digestSumNode(node trieNode) []byte { if node == nil { return spec.placeholder() @@ -209,8 +214,10 @@ func (spec *TrieSpec) digestSumNode(node trieNode) []byte { } if *cache == nil { preImage := spec.encodeSumNode(node) + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(preImage) *cache = spec.th.digestData(preImage) - *cache = append(*cache, preImage[len(preImage)-sumSizeBytes:]...) + *cache = append(*cache, preImage[firstSumByteIdx:firstCountByteIdx]...) + *cache = append(*cache, preImage[firstCountByteIdx:]...) } return *cache } @@ -239,34 +246,51 @@ func (spec *TrieSpec) parseExtNode(data []byte) (pathBounds, path, childData []b // parseSumLeafNode parses a leafNode and returns its weight as well // // nolint: unused -func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight uint64) { +func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight, count uint64) { // panics if not a leaf node checkPrefix(data, leafNodePrefix) path = data[prefixLen : prefixLen+spec.ph.PathSize()] value = data[prefixLen+spec.ph.PathSize():] + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data) + // Extract the sum from the encoded node data var weightBz [sumSizeBytes]byte - copy(weightBz[:], value[len(value)-sumSizeBytes:]) + copy(weightBz[:], data[firstSumByteIdx:firstCountByteIdx]) binary.BigEndian.PutUint64(weightBz[:], weight) + // Extract the count from the encoded node data + var countBz [countSizeBytes]byte + copy(countBz[:], value[firstCountByteIdx:]) + binary.BigEndian.PutUint64(countBz[:], count) + if count != 1 { + panic("count for leaf node should always be 1") + } + return } // parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data -func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData []byte, sum uint64) { +func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData []byte, sum, count uint64) { // panics if not an extension node checkPrefix(data, extNodePrefix) + firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(data) + // Extract the sum from the encoded node data var sumBz [sumSizeBytes]byte - copy(sumBz[:], data[len(data)-sumSizeBytes:]) + copy(sumBz[:], data[firstSumByteIdx:firstCountByteIdx]) binary.BigEndian.PutUint64(sumBz[:], sum) + // Extract the count from the encoded node data + var countBz [countSizeBytes]byte + copy(countBz[:], data[firstCountByteIdx:]) + binary.BigEndian.PutUint64(countBz[:], count) + // +2 represents the length of the pathBounds pathBounds = data[prefixLen : prefixLen+2] path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()] - childData = data[prefixLen+2+spec.ph.PathSize() : len(data)-sumSizeBytes] + childData = data[prefixLen+2+spec.ph.PathSize() : firstSumByteIdx] return } diff --git a/types.go b/types.go index fd6dacf..0bed0a0 100644 --- a/types.go +++ b/types.go @@ -18,6 +18,8 @@ var ( defaultEmptyValue []byte // defaultEmptySum is the default sum value for a leaf node defaultEmptySum [sumSizeBytes]byte + // defaultEmptyCount is the default count value for a leaf node + defaultEmptyCount [countSizeBytes]byte ) // MerkleRoot is a type 
alias for a byte slice returned from the Root method @@ -64,11 +66,13 @@ type SparseMerkleSumTrie interface { // Delete deletes a value from the SMST. Raises an error if the key is not present. Delete(key []byte) error // Get descends the trie to access a value. Returns nil if key is not present. - Get(key []byte) ([]byte, uint64, error) + Get(key []byte) (data []byte, sum uint64, err error) // Root computes the Merkle root digest. Root() MerkleRoot // Sum computes the total sum of the Merkle trie Sum() uint64 + // Count returns the total number of non-empty leaves in the trie + Count() uint64 // Prove computes a Merkle proof of inclusion or exclusion of a key. Prove(key []byte) (*SparseMerkleProof, error) // ProveClosest computes a Merkle proof of inclusion for a key in the trie