From 185749638d16cce724bfcb301ffaec7339bcc7ef Mon Sep 17 00:00:00 2001 From: Daniel Olshansky Date: Fri, 9 Feb 2024 15:54:07 -0800 Subject: [PATCH] Renamed a few things and converted the examples to proper tests --- extension_node.go | 8 ++++---- proofs.go | 4 ++-- smst.go | 2 +- smst_example_test.go | 33 ++++++++++++++++++++++++--------- smst_proofs_test.go | 2 +- smst_utils_test.go | 2 +- smt.go | 2 +- smt_example_test.go | 12 +++++++----- smt_proofs_test.go | 2 +- smt_utils_test.go | 2 +- types.go | 28 ++++++++++++++++------------ utils.go | 6 +++--- 12 files changed, 62 insertions(+), 41 deletions(-) diff --git a/extension_node.go b/extension_node.go index 88335f3..f2d2e32 100644 --- a/extension_node.go +++ b/extension_node.go @@ -132,11 +132,11 @@ func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { // expand returns the inner node that represents the start of the singly // linked list that this extension node represents -func (ext *extensionNode) expand() trieNode { - last := ext.child - for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- { +func (extNode *extensionNode) expand() trieNode { + last := extNode.child + for i := extNode.pathEnd() - 1; i >= extNode.pathStart(); i-- { var next innerNode - if getPathBit(ext.path, i) == leftChildBit { + if getPathBit(extNode.path, i) == leftChildBit { next.leftChild = last } else { next.rightChild = last diff --git a/proofs.go b/proofs.go index 211eed5..af4126c 100644 --- a/proofs.go +++ b/proofs.go @@ -283,7 +283,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) - valueHash := spec.digestValue(value) + valueHash := spec.valueDigest(value) valueHash = append(valueHash, sumBz[:]...) if bytes.Equal(value, defaultEmptyValue) && sum == 0 { valueHash = defaultEmptyValue @@ -348,7 +348,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v updates = append(updates, update) } } else { // Membership proof. - valueHash := spec.digestValue(value) + valueHash := spec.valueDigest(value) currentHash, currentData = digestLeaf(spec, path, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData diff --git a/smst.go b/smst.go index 394cdb3..58dc58b 100644 --- a/smst.go +++ b/smst.go @@ -84,7 +84,7 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { // appended with the binary representation of the weight provided. The weight // is used to compute the interim and total sum of the trie. func (smst *SMST) Update(key, value []byte, weight uint64) error { - valueHash := smst.digestValue(value) + valueHash := smst.valueDigest(value) var weightBz [sumSizeBits]byte binary.BigEndian.PutUint64(weightBz[:], weight) valueHash = append(valueHash, weightBz[:]...) diff --git a/smst_example_test.go b/smst_example_test.go index f1cdbca..79d1772 100644 --- a/smst_example_test.go +++ b/smst_example_test.go @@ -2,18 +2,21 @@ package smt_test import ( "crypto/sha256" - "fmt" + "testing" + + "github.com/stretchr/testify/require" "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) -func ExampleSMST() { - // Initialise a new in-memory key-value store to store the nodes of the trie +// TestExampleSMT is a test that aims to act as an example of how to use the SMST. +func TestExampleSMST(t *testing.T) { + // Initialize a new in-memory key-value store to store the nodes of the trie // (Note: the trie only stores hashed values, not raw value data) nodeStore := simplemap.NewSimpleMap() - // Initialise the trie + // Initialize the trie trie := smt.NewSparseMerkleSumTrie(nodeStore, sha256.New()) // Update trie with keys, values and their sums @@ -36,13 +39,25 @@ func ExampleSMST() { root := trie.Root() // Verify the Merkle proof for "foo"="oof" where "foo" has a sum of 10 - valid_true1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec()) + valid_true1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 10, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true1) + // Verify the Merkle proof for "baz"="zab" where "baz" has a sum of 7 - valid_true2, _ := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) + valid_true2, err := smt.VerifySumProof(proof2, root, []byte("baz"), []byte("zab"), 7, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true2) + // Verify the Merkle proof for "bin"="nib" where "bin" has a sum of 3 - valid_true3, _ := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) + valid_true3, err := smt.VerifySumProof(proof3, root, []byte("bin"), []byte("nib"), 3, trie.Spec()) + require.NoError(t, err) + require.True(t, valid_true3) + // Fail to verify the Merkle proof for "foo"="oof" where "foo" has a sum of 11 - valid_false1, _ := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) - fmt.Println(valid_true1, valid_true2, valid_true3, valid_false1) + valid_false1, err := smt.VerifySumProof(proof1, root, []byte("foo"), []byte("oof"), 11, trie.Spec()) + require.NoError(t, err) + require.False(t, valid_false1) + // Output: true true true false + t.Log(valid_true1, valid_true2, valid_true3, valid_false1) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index 394d595..41ad48d 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -105,7 +105,7 @@ func TestSMST_Proof_Operations(t *testing.T) { // Try proving a default value for a non-default leaf. var sum [sumSizeBits]byte binary.BigEndian.PutUint64(sum[:], 5) - tval := base.digestValue([]byte("testValue")) + tval := base.valueDigest([]byte("testValue")) tval = append(tval, sum[:]...) _, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) proof = &SparseMerkleProof{ diff --git a/smst_utils_test.go b/smst_utils_test.go index 90e27d9..4526fcc 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -26,7 +26,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { if err := smst.SMST.Update(key, value, sum); err != nil { return err } - valueHash := smst.digestValue(value) + valueHash := smst.valueDigest(value) var sumBz [sumSizeBits]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) diff --git a/smt.go b/smt.go index 18e075f..de20e07 100644 --- a/smt.go +++ b/smt.go @@ -108,7 +108,7 @@ func (smt *SMT) Get(key []byte) ([]byte, error) { func (smt *SMT) Update(key []byte, value []byte) error { // Expand path := smt.ph.Path(key) - valueHash := smt.digestValue(value) + valueHash := smt.valueDigest(value) var orphans orphanNodes trie, err := smt.update(smt.root, 0, path, valueHash, &orphans) if err != nil { diff --git a/smt_example_test.go b/smt_example_test.go index 6d74980..2f7af1b 100644 --- a/smt_example_test.go +++ b/smt_example_test.go @@ -2,18 +2,19 @@ package smt_test import ( "crypto/sha256" - "fmt" + "testing" "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) -func ExampleSMT() { - // Initialise a new in-memory key-value store to store the nodes of the trie +// TestExampleSMT is a test that aims to act as an example of how to use the SMST. +func TestExampleSMT(t *testing.T) { + // Initialize a new in-memory key-value store to store the nodes of the trie // (Note: the trie only stores hashed values, not raw value data) nodeStore := simplemap.NewSimpleMap() - // Initialise the trie + // Initialize the trie trie := smt.NewSparseMerkleTrie(nodeStore, sha256.New()) // Update the key "foo" with the value "bar" @@ -30,6 +31,7 @@ func ExampleSMT() { valid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("bar"), trie.Spec()) // Attempt to verify the Merkle proof for "foo"="baz" invalid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("baz"), trie.Spec()) - fmt.Println(valid, invalid) + // Output: true false + t.Log(valid, invalid) } diff --git a/smt_proofs_test.go b/smt_proofs_test.go index b1d5005..2364295 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -84,7 +84,7 @@ func TestSMT_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.digestValue([]byte("testValue"))) + _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.valueDigest([]byte("testValue"))) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, diff --git a/smt_utils_test.go b/smt_utils_test.go index cff400c..7137fa8 100644 --- a/smt_utils_test.go +++ b/smt_utils_test.go @@ -24,7 +24,7 @@ func (smt *SMTWithStorage) Update(key, value []byte) error { if err := smt.SMT.Update(key, value); err != nil { return err } - valueHash := smt.digestValue(value) + valueHash := smt.valueDigest(value) return smt.preimages.Set(valueHash, value) } diff --git a/types.go b/types.go index ab54818..8d438b2 100644 --- a/types.go +++ b/types.go @@ -108,35 +108,39 @@ func (spec *TrieSpec) Spec() *TrieSpec { // depth returns the maximum depth of the trie. // Since this tree is a binary tree, the depth is the number of bits in the path +// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits } -func (spec *TrieSpec) digestValue(data []byte) []byte { +// valueDigest returns the hash of a value, or the value itself if no value hasher is specified. +func (spec *TrieSpec) valueDigest(value []byte) []byte { if spec.vh == nil { - return data + return value } - return spec.vh.HashValue(data) + return spec.vh.HashValue(value) } -func (spec *TrieSpec) serialize(node trieNode) (data []byte) { +// encodeNode serializes a node into a byte slice +func (spec *TrieSpec) encodeNode(node trieNode) (data []byte) { switch n := node.(type) { case *lazyNode: - panic("serialize(lazyNode)") + panic("Encoding a lazyNode is not supported") case *leafNode: return encodeLeafNode(n.path, n.valueHash) case *innerNode: - lchild := spec.hashNode(n.leftChild) - rchild := spec.hashNode(n.rightChild) - return encodeInnerNode(lchild, rchild) + leftChild := spec.digestNode(n.leftChild) + rightChild := spec.digestNode(n.rightChild) + return encodeInnerNode(leftChild, rightChild) case *extensionNode: - child := spec.hashNode(n.child) + child := spec.digestNode(n.child) return encodeExtensionNode(n.pathBounds, n.path, child) } return nil } -func (spec *TrieSpec) hashNode(node trieNode) []byte { +// digestNode hashes a node returning its digest +func (spec *TrieSpec) digestNode(node trieNode) []byte { if node == nil { return spec.th.placeholder() } @@ -150,12 +154,12 @@ func (spec *TrieSpec) hashNode(node trieNode) []byte { cache = &n.digest case *extensionNode: if n.digest == nil { - n.digest = spec.hashNode(n.expand()) + n.digest = spec.digestNode(n.expand()) } return n.digest } if *cache == nil { - *cache = spec.th.digest(spec.serialize(node)) + *cache = spec.th.digest(spec.encodeNode(node)) } return *cache } diff --git a/utils.go b/utils.go index cda5ae2..801284f 100644 --- a/utils.go +++ b/utils.go @@ -157,7 +157,7 @@ func hashNode(spec *TrieSpec, node trieNode) []byte { if spec.sumTrie { return spec.hashSumNode(node) } - return spec.hashNode(node) + return spec.digestNode(node) } // serialize serializes a node depending on the trie type @@ -165,7 +165,7 @@ func serialize(spec *TrieSpec, node trieNode) []byte { if spec.sumTrie { return spec.sumSerialize(node) } - return spec.serialize(node) + return spec.encodeNode(node) } // hashPreimage hashes the serialised data provided depending on the trie type @@ -182,7 +182,7 @@ func hashSerialization(smt *TrieSpec, data []byte) []byte { pathBounds, path, childHash := parseExtension(data, smt.ph) ext := extensionNode{path: path, child: &lazyNode{childHash}} copy(ext.pathBounds[:], pathBounds) - return smt.hashNode(&ext) + return smt.digestNode(&ext) } return smt.th.digest(data) }