Skip to content

Commit

Permalink
Improved how placeholder values are maintained and how the node is re…
Browse files Browse the repository at this point in the history
…solved
  • Loading branch information
Olshansk committed Feb 11, 2024
1 parent cd24a78 commit 03d85ea
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 85 deletions.
2 changes: 1 addition & 1 deletion hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) {

func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte) {
dataWithoutSum := data[:len(data)-sumSizeBits]
leftData = dataWithoutSum[len(innerNodePrefix) : th.hashSize()+sumSizeBits+len(innerNodePrefix)]
leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBits]
rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBits:]
return
}
Expand Down
28 changes: 14 additions & 14 deletions node_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
// NB: In this file, all references to the variable `data` should be treated as `encodedNodeData`.
// It was abbreviated to `data` for brevity.

// TODO_TECHDEBT: We can easily use `iota` and ENUMS to create a wait to have
// more expressive code, and leverage switches statements throughout.
var (
leafNodePrefix = []byte{0}
innerNodePrefix = []byte{1}
Expand Down Expand Up @@ -114,22 +116,10 @@ func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byt

// encodeSumInnerNode encodes an inner node for an smst given the data for both children
func encodeSumInnerNode(leftData, rightData []byte) (data []byte) {
// Retrieve the sum of the left subtree
leftSum := uint64(0)
leftSumBz := leftData[len(leftData)-sumSizeBits:]
if !bytes.Equal(leftSumBz, defaultEmptySum[:]) {
leftSum = binary.BigEndian.Uint64(leftSumBz)
}

// Retrieve the sum of the right subtree
rightSum := uint64(0)
rightSumBz := rightData[len(rightData)-sumSizeBits:]
if !bytes.Equal(rightSumBz, defaultEmptySum[:]) {
rightSum = binary.BigEndian.Uint64(rightSumBz)
}

// Compute the sum of the current node
var sum [sumSizeBits]byte
leftSum := parseSum(leftData)
rightSum := parseSum(rightData)
binary.BigEndian.PutUint64(sum[:], leftSum+rightSum)

// Prepare and return the encoded inner node data
Expand Down Expand Up @@ -157,3 +147,13 @@ func checkPrefix(data, prefix []byte) {
panic("invalid prefix")
}
}

// parseSum parses the sum from the encoded node data
func parseSum(data []byte) uint64 {
sum := uint64(0)
sumBz := data[len(data)-sumSizeBits:]
if !bytes.Equal(sumBz, defaultEmptySum[:]) {
sum = binary.BigEndian.Uint64(sumBz)
}
return sum
}
6 changes: 3 additions & 3 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, v
var currentHash, currentData []byte
if bytes.Equal(value, defaultEmptyValue) { // Non-membership proof.
if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value.
currentHash = placeholder(spec)
currentHash = spec.placeholder()
} else { // Leaf is an unrelated leaf.
var actualPath, valueHash []byte
actualPath, valueHash = parseLeafNode(proof.NonMembershipLeafData, spec.ph)
Expand Down Expand Up @@ -412,7 +412,7 @@ func CompactProof(proof *SparseMerkleProof, spec *TrieSpec) (*SparseCompactMerkl
for i := 0; i < len(proof.SideNodes); i++ {
node := make([]byte, hashSize(spec))
copy(node, proof.SideNodes[i])
if bytes.Equal(node, placeholder(spec)) {
if bytes.Equal(node, spec.placeholder()) {
setPathBit(bitMask, i)
} else {
compactedSideNodes = append(compactedSideNodes, node)
Expand All @@ -438,7 +438,7 @@ func DecompactProof(proof *SparseCompactMerkleProof, spec *TrieSpec) (*SparseMer
position := 0
for i := 0; i < proof.NumSideNodes; i++ {
if getPathBit(proof.BitMask, i) == 1 {
decompactedSideNodes[i] = placeholder(spec)
decompactedSideNodes[i] = spec.placeholder()
} else {
decompactedSideNodes[i] = proof.SideNodes[position]
position++
Expand Down
18 changes: 9 additions & 9 deletions smst_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ func exportToCSV(
nodeStore kvstore.MapStore,
) {
t.Helper()
// rootHash := smst.Root()
// rootNode, err := nodeStore.Get(rootHash)
// require.NoError(t, err)

// Testing
// fmt.Println(isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode))
// leftChild, rightChild := smst.Spec().th.parseInnerNode(rootNode)
// // fmt.Println(isExtNode(leftChild), isExtNode(rightChild), rightChild, leftChild)
// fmt.Println(leftChild[:1], isExtNode(leftChild), isInnerNode(leftChild), isLeafNode(leftChild))
rootHash := smst.Root()
rootNode, err := nodeStore.Get(rootHash)
require.NoError(t, err)

leftChild, rightChild := smst.Spec().th.parseSumInnerNode(rootNode)
fmt.Println("Prefix", "isExt", "isLeaf", "isInner")
fmt.Println(rootNode[:1], isExtNode(rootNode), isLeafNode(rootNode), isInnerNode(rootNode))
fmt.Println(leftChild[:1], isExtNode(leftChild), isLeafNode(leftChild), isInnerNode(leftChild))
fmt.Println(rightChild[:1], isExtNode(rightChild), isLeafNode(rightChild), isInnerNode(rightChild))
// path, value := parseLeafNode(rightChild, smst.Spec().ph)
// path2, value2 := parseLeafNode(leftChild, smst.Spec().ph)
// fmt.Println(path, "~~~", value, "~~~", path2, "~~~", value2)
Expand Down
4 changes: 2 additions & 2 deletions smst_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestSMST_Proof_Operations(t *testing.T) {
proof, err = smst.Prove([]byte("testKey3"))
require.NoError(t, err)
checkCompactEquivalence(t, proof, base)
result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultEmptyValue, 0, base)
result, err = VerifySumProof(proof, base.placeholder(), []byte("testKey3"), defaultEmptyValue, 0, base)
require.NoError(t, err)
require.True(t, result)
result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base)
Expand Down Expand Up @@ -377,7 +377,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) {
Path: path[:],
FlippedBits: []int{0},
Depth: 0,
ClosestPath: placeholder(smst.Spec()),
ClosestPath: smst.placeholder(),
ClosestProof: &SparseMerkleProof{},
})

Expand Down
121 changes: 77 additions & 44 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func (smt *SMT) ProveClosest(path []byte) (

// Retrieve the closest path and value hash if found
if node == nil { // trie was empty
proof.ClosestPath, proof.ClosestValueHash = placeholder(smt.Spec()), nil
proof.ClosestPath, proof.ClosestValueHash = smt.placeholder(), nil
proof.ClosestProof = &SparseMerkleProof{}
return proof, nil
}
Expand Down Expand Up @@ -529,66 +529,99 @@ func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) {
return node, nil
}
if smt.sumTrie {
return smt.resolveSum(stub.digest)
return smt.resolveSumNode(stub.digest)
}
return smt.resolve(stub.digest)
return smt.resolveNode(stub.digest)
}

func (smt *SMT) resolve(hash []byte) (trieNode, error) {
if bytes.Equal(smt.th.placeholder(), hash) {
// resolveNode returns a trieNode (inner, leaf, or extension) based on what they
// keyHash points to.
func (smt *SMT) resolveNode(digest []byte) (trieNode, error) {
// Check if the keyHash is the empty zero value of an empty subtree
if bytes.Equal(smt.placeholder(), digest) {
return nil, nil
}
data, err := smt.nodes.Get(hash)

// Retrieve the encoded noe data
data, err := smt.nodes.Get(digest)
if err != nil {
return nil, err
}

// Return the appropriate node type based on the first byte of the data
if isLeafNode(data) {
leaf := leafNode{persisted: true, digest: hash}
leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph)
return &leaf, nil
}
if isExtNode(data) {
extNode := extensionNode{persisted: true, digest: hash}
pathBounds, path, childHash := parseExtNode(data, smt.ph)
extNode.path = path
copy(extNode.pathBounds[:], pathBounds)
extNode.child = &lazyNode{childHash}
return &extNode, nil
}
leftHash, rightHash := smt.th.parseInnerNode(data)
inner := innerNode{persisted: true, digest: hash}
inner.leftChild = &lazyNode{leftHash}
inner.rightChild = &lazyNode{rightHash}
return &inner, nil
path, valueHash := parseLeafNode(data, smt.ph)
return &leafNode{
path: path,
valueHash: valueHash,
persisted: true,
digest: digest,
}, nil
} else if isExtNode(data) {
pathBounds, path, childData := parseExtNode(data, smt.ph)
return &extensionNode{
path: path,
pathBounds: [2]byte(pathBounds),
child: &lazyNode{childData},
persisted: true,
digest: digest,
}, nil
} else if isInnerNode(data) {
leftData, rightData := smt.th.parseInnerNode(data)
return &innerNode{
leftChild: &lazyNode{leftData},
rightChild: &lazyNode{rightData},
persisted: true,
digest: digest,
}, nil
} else {
panic("invalid node type")
}
}

// resolveSum resolves
func (smt *SMT) resolveSum(hash []byte) (trieNode, error) {
if bytes.Equal(placeholder(smt.Spec()), hash) {
// resolveNode returns a trieNode (inner, leaf, or extension) based on what they
// keyHash points to.
func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) {
// Check if the keyHash is the empty zero value of an empty subtree
if bytes.Equal(smt.placeholder(), digest) {
return nil, nil
}
data, err := smt.nodes.Get(hash)

// Retrieve the encoded noe data
data, err := smt.nodes.Get(digest)
if err != nil {
return nil, err
}

// Return the appropriate node type based on the first byte of the data
if isLeafNode(data) {
leaf := leafNode{persisted: true, digest: hash}
leaf.path, leaf.valueHash = parseLeafNode(data, smt.ph)
return &leaf, nil
}
if isExtNode(data) {
extNode := extensionNode{persisted: true, digest: hash}
pathBounds, path, childHash, _ := parseSumExtNode(data, smt.ph)
extNode.path = path
copy(extNode.pathBounds[:], pathBounds)
extNode.child = &lazyNode{childHash}
return &extNode, nil
}
leftHash, rightHash := smt.th.parseSumInnerNode(data)
inner := innerNode{persisted: true, digest: hash}
inner.leftChild = &lazyNode{leftHash}
inner.rightChild = &lazyNode{rightHash}
return &inner, nil
path, valueHash := parseLeafNode(data, smt.ph)
return &leafNode{
path: path,
valueHash: valueHash,
persisted: true,
digest: digest,
}, nil
} else if isExtNode(data) {
pathBounds, path, childData, _ := parseSumExtNode(data, smt.ph)
return &extensionNode{
path: path,
pathBounds: [2]byte(pathBounds),
child: &lazyNode{childData},
persisted: true,
digest: digest,
}, nil
} else if isInnerNode(data) {
leftData, rightData := smt.th.parseSumInnerNode(data)
return &innerNode{
leftChild: &lazyNode{leftData},
rightChild: &lazyNode{rightData},
persisted: true,
digest: digest,
}, nil
} else {
panic("invalid node type")
}
}

// Commit persists all dirty nodes in the trie, deletes all orphaned
Expand Down
2 changes: 1 addition & 1 deletion smt_proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func TestSMT_ProveClosest_Empty(t *testing.T) {
Path: path[:],
FlippedBits: []int{0},
Depth: 0,
ClosestPath: placeholder(smt.Spec()),
ClosestPath: smt.placeholder(),
ClosestProof: &SparseMerkleProof{},
})

Expand Down
12 changes: 11 additions & 1 deletion types.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ func (spec *TrieSpec) Spec() *TrieSpec {
return spec
}

// placeholder returns the default placeholder value depending on the trie type
func (spec *TrieSpec) placeholder() []byte {
if spec.sumTrie {
placeholder := spec.th.placeholder()
placeholder = append(placeholder, defaultEmptySum[:]...)
return placeholder
}
return spec.th.placeholder()
}

// depth returns the maximum depth of the trie.
// Since this tree is a binary tree, the depth is the number of bits in the path
// TODO_IN_THIS_PR: Try to understand why we're not taking the log of the output
Expand Down Expand Up @@ -190,7 +200,7 @@ func (spec *TrieSpec) sumSerialize(node trieNode) (preImage []byte) {
// digest = [node hash]+[8 byte sum]
func (spec *TrieSpec) hashSumNode(node trieNode) []byte {
if node == nil {
return placeholder(spec)
return spec.placeholder()
}
var cache *[]byte
switch n := node.(type) {
Expand Down
10 changes: 0 additions & 10 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ func bytesToInt(bz []byte) int {
return int(u)
}

// placeholder returns the default placeholder value depending on the trie type
func placeholder(spec *TrieSpec) []byte {
if spec.sumTrie {
placeholder := spec.th.placeholder()
placeholder = append(placeholder, defaultEmptySum[:]...)
return placeholder
}
return spec.th.placeholder()
}

// hashSize returns the hash size depending on the trie type
func hashSize(spec *TrieSpec) int {
if spec.sumTrie {
Expand Down

0 comments on commit 03d85ea

Please sign in to comment.