Skip to content

Commit

Permalink
Fix BPT receipts
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Oct 9, 2023
1 parent 072255c commit b3b8f83
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 67 deletions.
133 changes: 66 additions & 67 deletions pkg/database/bpt/bpt_receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,87 +7,86 @@
package bpt

import (
"crypto/sha256"

"gitlab.com/accumulatenetwork/accumulate/pkg/errors"
"gitlab.com/accumulatenetwork/accumulate/pkg/types/merkle"
)

// collectReceipt
// A recursive routine that searches the BPT for the given chainID. Once it is
// found, the search unwinds and builds the receipt.
//
// Inputs:
// BIdx -- byte index into the key
// bit -- index to the bit
// node -- the node in the BPT where we have reached in our search so far
// key -- The key in the BPT we are looking for
func (b *BPT) collectReceipt(BIdx, bit byte, n *branch, key [32]byte, r *merkle.Receipt) (hash []byte) {
// Load the node and hope it doesn't fail
_ = n.load()
// GetReceipt constructs a receipt for the current state for the given key.
func (b *BPT) GetReceipt(key [32]byte) (*merkle.Receipt, error) {
const debug = false

var entry, other node // The node has a left or right entry that builds a tree.
bite := key[BIdx] // Get the byte for debugging.
right := bit&bite == 0 // Flag for going right or left up the tree depends on a bit in the key
entry = n.Right // Guess we are going right (that the current bit is 1)
other = n.Left // We will need the other path as well.
if !right { // If the bit isn't 1, then we are NOT going right
entry = n.Left // then go left
other = n.Right // and the right will be the other path
err := b.executePending()
if err != nil {
return nil, errors.UnknownError.Wrap(err)
}

value, ok := entry.(*leaf)
if ok {
if value.Key.Hash() == key {
r.Start = append(r.Start[:0], value.Hash[:]...)
if other != nil { // If other isn't nil, then add it to the node list of the receipt
h, _ := other.getHash()
r.Entries = append(r.Entries,
&merkle.ReceiptEntry{Hash: h[:], Right: !right})
}
return append([]byte{}, n.Hash[:]...) // Note that the node.Hash is combined with other if other != nil
}
return nil
}
nextNode, ok := entry.(*branch)
if !ok {
return nil
// Find the leaf node
n, err := b.getRoot().getLeaf(key)
if err != nil {
return nil, errors.UnknownError.Wrap(err)
}

// We have processed the current bit. Now move to the next bit.
// Increment the bit index. If the set bit is still in the byte, we are done.
// If the bit rolls out of the byte, then set the low order bit, and increment the byte index.
bit >>= 1 //
if bit == 0 { // performance. What we are doing is shifting the
bit = 0x80 // bit test up on each level of the Merkle tree. If the bit
BIdx++ // shifts out of a BIdx, we increment the BIdx and start over
}
// Walk up the tree
receipt := new(merkle.Receipt)
receipt.Start = n.Hash[:]
working := n.Hash
var br *branch
for n := node(n); ; n = br {
// Get the parent
switch n := n.(type) {
case *branch:
br = n.parent
case *leaf:
br = n.parent
default:
panic("invalid node")
}
if br == nil {
break
}

childhash := b.collectReceipt(BIdx, bit, nextNode, key, r)
if childhash == nil {
return nil
}
// Skip any branches that only have one side populated
nh, _ := n.getHash()
brh, _ := br.getHash()
if nh == brh {
continue
}

if other != nil {
// Add the hash to the receipt provided by the entry, and mark it right or not right (right flag)
h, _ := other.getHash()
r.Entries = append(r.Entries, &merkle.ReceiptEntry{Hash: h[:], Right: !right})
}
// Calculate the next entry
var entry *merkle.ReceiptEntry
switch n {
case br.Right:
h, _ := br.Left.getHash()
entry = &merkle.ReceiptEntry{Hash: h[:], Right: false}
if debug {
working = sha256.Sum256(append(h[:], working[:]...))
}

h, _ := n.getHash()
return h[:]
}
case br.Left:
h, _ := br.Right.getHash()
entry = &merkle.ReceiptEntry{Hash: h[:], Right: true}
if debug {
working = sha256.Sum256(append(working[:], h[:]...))
}

// GetReceipt
// Returns the receipt for the current state for the given chainID
func (b *BPT) GetReceipt(key [32]byte) (*merkle.Receipt, error) { // The location of a value is determined by the chainID (a key)
err := b.executePending()
if err != nil {
return nil, errors.UnknownError.Wrap(err)
default:
panic("invalid tree")
}
if debug && brh != working {
panic("inconsistent BPT")
}

// Add it and move on
receipt.Entries = append(receipt.Entries, entry)
}

receipt := new(merkle.Receipt)
receipt.Anchor = b.collectReceipt(0, 0x80, b.getRoot(), key, receipt) //
if receipt.Anchor == nil {
return nil, errors.NotFound.WithFormat("key %x not found", key)
h, _ := b.getRoot().getHash()
receipt.Anchor = h[:]

if !receipt.Validate(nil) {
panic("constructed invalid receipt")
}
return receipt, nil
} //
}
55 changes: 55 additions & 0 deletions pkg/database/bpt/bpt_receipt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2023 The Accumulate Authors
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

package bpt

import (
"bytes"
"sort"
"testing"

"github.com/stretchr/testify/require"
"gitlab.com/accumulatenetwork/accumulate/internal/database/smt/common"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue/memory"
"gitlab.com/accumulatenetwork/accumulate/pkg/types/record"
)

// TestBPT_receipt
// Build a reasonable size BPT, then prove we can create a receipt for every
// element in said BPT.
func TestBPT_receipt(t *testing.T) {
kvs := memory.New(nil).Begin(nil, true)
store := keyvalue.RecordStore{Store: kvs}
bpt := newBPT(nil, nil, store, nil, "BPT")

numberEntries := 50000 // A pretty reasonable sized BPT

var keys, values common.RandHash // use the default sequence for keys
values.SetSeed([]byte{1, 2, 3}) // use a different sequence for values
for i := 0; i < numberEntries; i++ { // For the number of Entries specified for the BPT
chainID := keys.NextAList() // Get a key, keep a list
value := values.NextA() // Get some value (don't really care what it is)
err := bpt.Insert(record.KeyFromHash(chainID), value) // Insert the Key with the value into the BPT
require.NoError(t, err)
}
require.NoError(t, bpt.Commit())

keyList := append([][]byte{}, keys.List...) // Get all the keys
sort.Slice(keyList, func(i, j int) bool { return bytes.Compare(keyList[i], keyList[j]) < 0 }) //

// Recreate the BPT to throw away the pending list, otherwise GetReceipt
// spends a lot of time iterating over it and doing nothing
bpt = newBPT(nil, nil, store, nil, "BPT")

// Make sure every key we added to the BPT has a valid receipt
for i := range keys.List { // go through the list of keys
r, err := bpt.GetReceipt(keys.GetAElement(i))
require.NoError(t, err)
v := r.Validate(nil)
require.Truef(t, v, "should validate BPT element %d", i+1)
}
}
91 changes: 91 additions & 0 deletions pkg/database/bpt/bpt_savestate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2023 The Accumulate Authors
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

package bpt

import (
"crypto/sha256"
"fmt"
"io"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
"gitlab.com/accumulatenetwork/accumulate/internal/database/smt/common"
"gitlab.com/accumulatenetwork/accumulate/internal/database/smt/storage"
ioutil2 "gitlab.com/accumulatenetwork/accumulate/internal/util/io"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue/badger"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue/memory"
"gitlab.com/accumulatenetwork/accumulate/pkg/types/record"
)

func TestSaveState(t *testing.T) {

numberEntries := 5001 // A pretty reasonable sized BPT

DirName, err := os.MkdirTemp("", "AccDB")
require.Nil(t, err, "failed to create directory")
defer os.RemoveAll(DirName)

BDB, err := badger.New(DirName + "/add")
require.Nil(t, err, "failed to create db")
defer BDB.Close()

storeTx := BDB.Begin(nil, true) // and begin its use.
store := keyvalue.RecordStore{Store: storeTx} //
bpt := newBPT(nil, nil, store, nil, "BPT") // Create a BptManager. We will create a new one each cycle.
var keys, values common.RandHash // use the default sequence for keys
values.SetSeed([]byte{1, 2, 3}) // use a different sequence for values
for i := 0; i < numberEntries; i++ { // For the number of Entries specified for the BPT
chainID := keys.NextAList() // Get a key, keep a list
value := values.GetRandBuff(int(values.GetRandInt64() % 100))
hash := sha256.Sum256(value)
err := storeTx.Put(record.KeyFromHash(hash), value)
require.NoError(t, err)
err = bpt.Insert(record.KeyFromHash(chainID), hash) // Insert the Key with the value into the BPT
require.NoError(t, err)
}
require.NoError(t, bpt.Commit())
storeTx = BDB.Begin(nil, true)
store = keyvalue.RecordStore{Store: storeTx}
bpt = newBPT(nil, nil, store, nil, "BPT")

f, err := os.Create(filepath.Join(DirName, "SnapShot"))
require.NoError(t, err)
defer f.Close()

err = SaveSnapshotV1(bpt, f, func(key storage.Key, hash [32]byte) ([]byte, error) {
return storeTx.Get(record.KeyFromHash(hash))
})
require.NoError(t, err)

_, err = f.Seek(0, io.SeekStart)
require.NoError(t, err)

kvs2 := memory.New(nil).Begin(nil, true)
store2 := keyvalue.RecordStore{Store: kvs2}
bpt2 := newBPT(nil, nil, store2, nil, "BPT")
err = LoadSnapshotV1(bpt2, f, func(key storage.Key, hash [32]byte, reader ioutil2.SectionReader) error {
value, err := io.ReadAll(reader)
if err != nil {
return err
}
valueHash := sha256.Sum256(value)
if hash != valueHash {
return fmt.Errorf("hash does not match for key %X", key)
}
return nil
})
require.NoError(t, err)
require.NoError(t, bpt.Commit())
r1, err := bpt.GetRootHash()
require.NoError(t, err)
r2, err := bpt2.GetRootHash()
require.NoError(t, err)
require.Equal(t, r1, r2)
}
65 changes: 65 additions & 0 deletions pkg/database/bpt/iterate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2023 The Accumulate Authors
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

package bpt

import (
"bytes"
"fmt"
"sort"
"testing"

"github.com/stretchr/testify/require"
"gitlab.com/accumulatenetwork/accumulate/internal/database/smt/common"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue"
"gitlab.com/accumulatenetwork/accumulate/pkg/database/keyvalue/memory"
"gitlab.com/accumulatenetwork/accumulate/pkg/types/record"
)

var _ = fmt.Print

func TestGetRange(t *testing.T) {
for i := 0; i < 15000; i += 1017 {
GetRangeFor(t, i, 13)
}
}

func GetRangeFor(t *testing.T, numberEntries, rangeNum int) {
fmt.Println(numberEntries)
kvs := memory.New(nil).Begin(nil, true) // Build a BPT
store := keyvalue.RecordStore{Store: kvs} //
bpt := newBPT(nil, nil, store, nil, "BPT") //
var keys, values common.RandHash // use the default sequence for keys
values.SetSeed([]byte{1, 2, 3}) // use a different sequence for values
for i := 0; i < numberEntries; i++ { // For the number of Entries specified for the BPT
chainID := keys.NextAList() // Get a key, keep a list
value := values.NextA() // Get some value (don't really care what it is)
err := bpt.Insert(record.KeyFromHash(chainID), value) // Insert the Key with the value into the BPT
require.NoError(t, err)
}

cnt := 0

// The BPT will sort the keys, so we take the list of keys we used, and sort them
sort.Slice(keys.List, func(i, j int) bool { return bytes.Compare(keys.List[i], keys.List[j]) > 0 })

// We ask for a range of rangeNum entries at a time.
it := bpt.Iterate(13)
for i := 0; i < numberEntries; i += rangeNum {
require.True(t, it.Next(), "Must be more BPT entries")
bptValues := it.Value()
if len(bptValues) == 0 {
break
}
for j, v := range bptValues {
k := keys.List[i+j]
require.Truef(t, bytes.Equal(v.Key[:], k), "i,j= %d:%d %02x should be %02x", i, j, v.Key[:2], k[:2])
}
cnt += len(bptValues)
}

require.Equal(t, len(keys.List), cnt)
}

0 comments on commit b3b8f83

Please sign in to comment.