Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Code Health] refactor: SMST#Root(), #Sum(), & #Count() #51

Merged
merged 10 commits into from
Jul 17, 2024
89 changes: 65 additions & 24 deletions root.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,11 @@ import (
"fmt"
)

const (
// These are intentionally exposed to allow for for testing and custom
// implementations of downstream applications.
SmtRootSizeBytes = 32
SmstRootSizeBytes = SmtRootSizeBytes + sumSizeBytes + countSizeBytes
)

// MustSum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will panic.
func (r MerkleRoot) MustSum() uint64 {
sum, err := r.Sum()
func (root MerkleSumRoot) MustSum() uint64 {
sum, err := root.Sum()
if err != nil {
panic(err)
}
Expand All @@ -27,28 +20,76 @@ func (r MerkleRoot) MustSum() uint64 {
// Sum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will return an error.
func (r MerkleRoot) Sum() (uint64, error) {
if len(r)%SmtRootSizeBytes == 0 {
return 0, fmt.Errorf("root#sum: not a merkle sum trie")
func (root MerkleSumRoot) Sum() (uint64, error) {
if err := root.validateBasic(); err != nil {
return 0, err
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return root.sum(), nil
}

// MustCount returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree. It panics if the root length
// is invalid.
func (root MerkleSumRoot) MustCount() uint64 {
count, err := root.Count()
if err != nil {
panic(err)
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:]), nil
return count
}

// Count returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree.
func (r MerkleRoot) Count() uint64 {
if len(r)%SmtRootSizeBytes == 0 {
panic("root#sum: not a merkle sum trie")
// count of the number of non-empty leafs in the tree. It returns an error if the
// root length is invalid.
func (root MerkleSumRoot) Count() (uint64, error) {
if err := root.validateBasic(); err != nil {
return 0, err
}

_, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return root.count(), nil
}

// DigestSize returns the length of the digest portion of the root.
func (root MerkleSumRoot) DigestSize() int {
return len(root) - countSizeBytes - sumSizeBytes
}

// HasDigestSize returns true if the root hash (digest) length is the same as
bryanchriswhite marked this conversation as resolved.
Show resolved Hide resolved
// that of the size of the given hasher.
func (root MerkleSumRoot) HasDigestSize(size int) bool {
return root.DigestSize() == size
}

var countBz [countSizeBytes]byte
copy(countBz[:], []byte(r)[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
// validateBasic returns an error if the root (digest) length is not a power of two.
bryanchriswhite marked this conversation as resolved.
Show resolved Hide resolved
func (root MerkleSumRoot) validateBasic() error {
if !isPowerOfTwo(root.DigestSize()) {
return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length")
}

return nil
}

// sum returns the sum of the node stored in the root.
func (root MerkleSumRoot) sum() uint64 {
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root)

return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx])
}

// count returns the count of the node stored in the root.
func (root MerkleSumRoot) count() uint64 {
_, firstCountByteIdx := getFirstMetaByteIdx(root)

return binary.BigEndian.Uint64(root[firstCountByteIdx:])
}

// isPowerOfTwo function returns true if the input n is a power of 2
func isPowerOfTwo(n int) bool {
// A power of 2 has only one bit set in its binary representation
if n <= 0 {
return false
}
return (n & (n - 1)) == 0
}
100 changes: 59 additions & 41 deletions root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,82 @@ import (
"github.com/pokt-network/smt/kvstore/simplemap"
)

func TestMerkleRoot_TrieTypes(t *testing.T) {
func TestMerkleSumRoot_SumAndCountSuccess(t *testing.T) {
tests := []struct {
desc string
sumTree bool
hasher hash.Hash
expectedPanic string
desc string
hasher hash.Hash
}{
{
desc: "successfully: gets sum of sha256 hasher SMST",
sumTree: true,
hasher: sha256.New(),
expectedPanic: "",
desc: "sha256 hasher",
hasher: sha256.New(),
},
{
desc: "successfully: gets sum of sha512 hasher SMST",
sumTree: true,
hasher: sha512.New(),
expectedPanic: "",
desc: "sha512 hasher",
hasher: sha512.New(),
},
}

nodeStore := simplemap.NewSimpleMap()
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, nodeStore.ClearAll())
})
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}

sum, sumErr := trie.Sum()
require.NoError(t, sumErr)

count, countErr := trie.Count()
require.NoError(t, countErr)

require.EqualValues(t, uint64(45), sum)
require.EqualValues(t, uint64(10), count)
})
}
}

func TestMekleRoot_SumAndCountError(t *testing.T) {
tests := []struct {
desc string
hasher hash.Hash
}{
{
desc: "failure: panics for sha256 hasher SMT",
sumTree: false,
hasher: sha256.New(),
expectedPanic: "roo#sum: not a merkle sum trie",
desc: "sha256 hasher",
hasher: sha256.New(),
},
{
desc: "failure: panics for sha512 hasher SMT",
sumTree: false,
hasher: sha512.New(),
expectedPanic: "roo#sum: not a merkle sum trie",
desc: "sha512 hasher",
hasher: sha512.New(),
},
}

nodeStore := simplemap.NewSimpleMap()
for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, nodeStore.ClearAll())
})
if tt.sumTree {
trie := smt.NewSparseMerkleSumTrie(nodeStore, tt.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}
require.NotNil(t, trie.Sum())
require.EqualValues(t, 45, trie.Sum())
require.EqualValues(t, 10, trie.Count())

return
}
trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher)
for i := 0; i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))))
}
if panicStr := recover(); panicStr != nil {
require.Equal(t, tt.expectedPanic, panicStr)
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}

root := trie.Root()

// Mangle the root bytes.
root = root[:len(root)-1]

sum, sumErr := root.Sum()
require.Error(t, sumErr)
require.Equal(t, uint64(0), sum)

count, countErr := root.Count()
require.Error(t, countErr)
require.Equal(t, uint64(0), count)
})
}
}
50 changes: 30 additions & 20 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package smt
import (
"bytes"
"encoding/binary"
"fmt"
"hash"

"github.com/pokt-network/smt/kvstore"
Expand Down Expand Up @@ -170,39 +171,48 @@ func (smst *SMST) Commit() error {
}

// Root returns the root hash of the trie with the total sum bytes appended
func (smst *SMST) Root() MerkleRoot {
return smst.SMT.Root() // [digest]+[binary sum]
func (smst *SMST) Root() MerkleSumRoot {
return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count]
}

// Sum returns the sum of the entire trie stored in the root.
// MustSum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will panic.
func (smst *SMST) Sum() uint64 {
rootDigest := []byte(smst.Root())
func (smst *SMST) MustSum() uint64 {
sum, err := smst.Sum()
if err != nil {
panic(err)
}
return sum
}

// Sum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will return an error.
func (smst *SMST) Sum() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
return smst.Root().Sum()
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:])
// MustCount returns the number of non-empty nodes in the entire trie stored in the root.
bryanchriswhite marked this conversation as resolved.
Show resolved Hide resolved
// If the tree is not a sum tree, it will panic.
func (smst *SMST) MustCount() uint64 {
count, err := smst.Count()
if err != nil {
panic(err)
}
return count
}

// Count returns the number of non-empty nodes in the entire trie stored in the root.
bryanchriswhite marked this conversation as resolved.
Show resolved Hide resolved
func (smst *SMST) Count() uint64 {
rootDigest := []byte(smst.Root())

// If the tree is not a sum tree, it will return an error.
func (smst *SMST) Count() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)

var countBz [countSizeBytes]byte
copy(countBz[:], rootDigest[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
return smst.Root().Count()
}

// getFirstMetaByteIdx returns the index of the first count byte and the first sum byte
Expand All @@ -211,5 +221,5 @@ func (smst *SMST) Count() uint64 {
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
firstCountByteIdx = len(data) - countSizeBytes
firstSumByteIdx = firstCountByteIdx - sumSizeBytes
return
return firstSumByteIdx, firstCountByteIdx
}
6 changes: 3 additions & 3 deletions smst_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestExampleSMST(t *testing.T) {
_ = trie.Commit()

// Calculate the total sum of the trie
_ = trie.Sum() // 20
_ = trie.MustSum() // 20

// Generate a Merkle proof for "foo"
proof1, _ := trie.Prove([]byte("foo"))
Expand All @@ -52,8 +52,8 @@ func TestExampleSMST(t *testing.T) {
require.False(t, valid_false1)

// Verify the total sum of the trie
require.EqualValues(t, 20, trie.Sum())
require.EqualValues(t, 20, trie.MustSum())

// Verify the number of non-empty leafs in the trie
require.EqualValues(t, 3, trie.Count())
require.EqualValues(t, 3, trie.MustCount())
}
Loading
Loading