Skip to content

Commit

Permalink
fix(ph4): bmt (#4308)
Browse files Browse the repository at this point in the history
  • Loading branch information
nugaon authored Sep 21, 2023
1 parent 038dbfb commit 1f84ef4
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 50 deletions.
13 changes: 13 additions & 0 deletions pkg/bmt/bmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ type Hasher struct {
span []byte // The span of the data subsumed under the chunk
}

// facade
func NewHasher(hasherFact func() hash.Hash) *Hasher {
conf := NewConf(hasherFact, swarm.BmtBranches, 32)

return &Hasher{
Conf: conf,
result: make(chan []byte),
errc: make(chan error, 1),
span: make([]byte, SpanSize),
bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher),
}
}

// Capacity returns the maximum amount of bytes that will be processed by this hasher implementation.
// since BMT assumes a balanced binary tree, capacity it is always a power of 2
func (h *Hasher) Capacity() int {
Expand Down
39 changes: 33 additions & 6 deletions pkg/bmt/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ type Proof struct {
Index int
}

// Override base hash function of Hasher to fill buffer with zeros until chunk length
func (p Prover) Hash(b []byte) ([]byte, error) {
for i := p.size; i < p.maxSize; i += len(zerosection) {
_, err := p.Write(zerosection)
if err != nil {
return []byte{}, err
}
}
return p.Hasher.Hash(b)
}

// Proof returns the inclusion proof of the i-th data segment
func (p Prover) Proof(i int) Proof {
index := i
Expand All @@ -36,34 +47,50 @@ func (p Prover) Proof(i int) Proof {
secsize := 2 * p.segmentSize
offset := i * secsize
section := p.bmt.buffer[offset : offset+secsize]
return Proof{section, sisters, p.span, index}
left := section[:p.segmentSize]
right := section[p.segmentSize:]
var segment, firstSegmentSister []byte
if index%2 == 0 {
segment, firstSegmentSister = left, right
} else {
segment, firstSegmentSister = right, left
}
sisters = append([][]byte{firstSegmentSister}, sisters...)
return Proof{segment, sisters, p.span, index}
}

// Verify returns the bmt hash obtained from the proof which can then be checked against
// the BMT hash of the chunk
func (p Prover) Verify(i int, proof Proof) (root []byte, err error) {
var section []byte
if i%2 == 0 {
section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...)
} else {
section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...)
}
i = i / 2
n := p.bmt.leaves[i]
hasher := p.hasher()
isLeft := n.isLeft
root, err = doHash(n.hasher, proof.ProveSegment)
root, err = doHash(hasher, section)
if err != nil {
return nil, err
}
n = n.parent

for _, sister := range proof.ProofSegments {
for _, sister := range proof.ProofSegments[1:] {
if isLeft {
root, err = doHash(n.hasher, root, sister)
root, err = doHash(hasher, root, sister)
} else {
root, err = doHash(n.hasher, sister, root)
root, err = doHash(hasher, sister, root)
}
if err != nil {
return nil, err
}
isLeft = n.isLeft
n = n.parent
}
return sha3hash(proof.Span, root)
return doHash(hasher, proof.Span, root)
}

func (n *node) getSister(isLeft bool) []byte {
Expand Down
30 changes: 18 additions & 12 deletions pkg/bmt/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ func TestProofCorrectness(t *testing.T) {
t.Parallel()

testData := []byte("hello world")
testData = append(testData, make([]byte, 4096-len(testData))...)
testDataPadded := make([]byte, swarm.ChunkSize)
copy(testDataPadded, testData)

verifySegments := func(t *testing.T, exp []string, found [][]byte) {
t.Helper()
Expand Down Expand Up @@ -57,18 +58,19 @@ func TestProofCorrectness(t *testing.T) {
if err != nil {
t.Fatal(err)
}

rh, err := hh.Hash(nil)
pr := bmt.Prover{hh}
rh, err := pr.Hash(nil)
if err != nil {
t.Fatal(err)
}

t.Run("proof for left most", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(0)
proof := pr.Proof(0)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -79,7 +81,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[:2*hh.Size()]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) {
t.Fatal("section incorrect")
}

Expand All @@ -91,9 +93,10 @@ func TestProofCorrectness(t *testing.T) {
t.Run("proof for right most", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(127)
proof := pr.Proof(127)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -104,7 +107,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[126*hh.Size():]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) {
t.Fatal("section incorrect")
}

Expand All @@ -116,9 +119,10 @@ func TestProofCorrectness(t *testing.T) {
t.Run("proof for middle", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(64)
proof := pr.Proof(64)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -129,7 +133,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[64*hh.Size():66*hh.Size()]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) {
t.Fatal("section incorrect")
}

Expand All @@ -142,6 +146,7 @@ func TestProofCorrectness(t *testing.T) {
t.Parallel()

segmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -159,9 +164,9 @@ func TestProofCorrectness(t *testing.T) {
segments = append(segments, decoded)
}

segment := testData[64*hh.Size() : 66*hh.Size()]
segment := testDataPadded[64*hh.Size() : 65*hh.Size()]

rootHash, err := bmt.Prover{hh}.Verify(64, bmt.Proof{
rootHash, err := pr.Verify(64, bmt.Proof{
ProveSegment: segment,
ProofSegments: segments,
Span: bmt.LengthToSpan(4096),
Expand Down Expand Up @@ -200,6 +205,7 @@ func TestProof(t *testing.T) {
}

rh, err := hh.Hash(nil)
pr := bmt.Prover{hh}
if err != nil {
t.Fatal(err)
}
Expand All @@ -209,7 +215,7 @@ func TestProof(t *testing.T) {
t.Run(fmt.Sprintf("segmentIndex %d", i), func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(i)
proof := pr.Proof(i)

h := pool.Get()
defer pool.Put(h)
Expand Down
25 changes: 0 additions & 25 deletions pkg/bmt/trhasher.go

This file was deleted.

9 changes: 8 additions & 1 deletion pkg/storer/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/hmac"
"encoding/binary"
"fmt"
"hash"
"math/big"
"sort"
"sync"
Expand Down Expand Up @@ -43,7 +44,13 @@ type Sample struct {
func RandSample(t *testing.T, anchor []byte) Sample {
t.Helper()

hasher := bmt.NewTrHasher(anchor)
prefixHasherFactory := func() hash.Hash {
return swarm.NewPrefixHasher(anchor)
}
pool := bmt.NewPool(bmt.NewConf(prefixHasherFactory, swarm.BmtBranches, 8))

hasher := pool.Get()
defer pool.Put(hasher)

items := make([]SampleItem, SampleSize)
for i := 0; i < SampleSize; i++ {
Expand Down
10 changes: 5 additions & 5 deletions pkg/swarm/hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ func NewHasher() hash.Hash {
return sha3.NewLegacyKeccak256()
}

type trHasher struct {
type PrefixHasher struct {
hash.Hash
prefix []byte
}

// NewTrHasher returns new hasher which is Keccak-256 hasher
// NewPrefixHasher returns new hasher which is Keccak-256 hasher
// with prefix value added as initial data.
func NewTrHasher(prefix []byte) hash.Hash {
h := &trHasher{
func NewPrefixHasher(prefix []byte) hash.Hash {
h := &PrefixHasher{
Hash: NewHasher(),
prefix: prefix,
}
Expand All @@ -32,7 +32,7 @@ func NewTrHasher(prefix []byte) hash.Hash {
return h
}

func (h *trHasher) Reset() {
func (h *PrefixHasher) Reset() {
h.Hash.Reset()
_, _ = h.Write(h.prefix)
}
2 changes: 1 addition & 1 deletion pkg/swarm/hasher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestNewTrHasher(t *testing.T) {

// Run tests cases against TrHasher
for _, tc := range tests {
h := swarm.NewTrHasher(tc.prefix)
h := swarm.NewPrefixHasher(tc.prefix)

_, err := h.Write(tc.plaintext)
if err != nil {
Expand Down

0 comments on commit 1f84ef4

Please sign in to comment.