From 1f84ef4532ef64e59a47ea112402e624ad18934d Mon Sep 17 00:00:00 2001 From: nugaon <50576770+nugaon@users.noreply.github.com> Date: Thu, 21 Sep 2023 15:02:05 +0200 Subject: [PATCH] fix(ph4): bmt (#4308) --- pkg/bmt/bmt.go | 13 +++++++++++++ pkg/bmt/proof.go | 39 +++++++++++++++++++++++++++++++++------ pkg/bmt/proof_test.go | 30 ++++++++++++++++++------------ pkg/bmt/trhasher.go | 25 ------------------------- pkg/storer/sample.go | 9 ++++++++- pkg/swarm/hasher.go | 10 +++++----- pkg/swarm/hasher_test.go | 2 +- 7 files changed, 78 insertions(+), 50 deletions(-) delete mode 100644 pkg/bmt/trhasher.go diff --git a/pkg/bmt/bmt.go b/pkg/bmt/bmt.go index e13aa3ec1cb..38c0e8bff5f 100644 --- a/pkg/bmt/bmt.go +++ b/pkg/bmt/bmt.go @@ -40,6 +40,19 @@ type Hasher struct { span []byte // The span of the data subsumed under the chunk } +// facade +func NewHasher(hasherFact func() hash.Hash) *Hasher { + conf := NewConf(hasherFact, swarm.BmtBranches, 32) + + return &Hasher{ + Conf: conf, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher), + } +} + // Capacity returns the maximum amount of bytes that will be processed by this hasher implementation. // since BMT assumes a balanced binary tree, capacity it is always a power of 2 func (h *Hasher) Capacity() int { diff --git a/pkg/bmt/proof.go b/pkg/bmt/proof.go index b9a958db9ab..fa39174c3c5 100644 --- a/pkg/bmt/proof.go +++ b/pkg/bmt/proof.go @@ -17,6 +17,17 @@ type Proof struct { Index int } +// Override base hash function of Hasher to fill buffer with zeros until chunk length +func (p Prover) Hash(b []byte) ([]byte, error) { + for i := p.size; i < p.maxSize; i += len(zerosection) { + _, err := p.Write(zerosection) + if err != nil { + return []byte{}, err + } + } + return p.Hasher.Hash(b) +} + // Proof returns the inclusion proof of the i-th data segment func (p Prover) Proof(i int) Proof { index := i @@ -36,26 +47,42 @@ func (p Prover) Proof(i int) Proof { secsize := 2 * p.segmentSize offset := i * secsize section := p.bmt.buffer[offset : offset+secsize] - return Proof{section, sisters, p.span, index} + left := section[:p.segmentSize] + right := section[p.segmentSize:] + var segment, firstSegmentSister []byte + if index%2 == 0 { + segment, firstSegmentSister = left, right + } else { + segment, firstSegmentSister = right, left + } + sisters = append([][]byte{firstSegmentSister}, sisters...) + return Proof{segment, sisters, p.span, index} } // Verify returns the bmt hash obtained from the proof which can then be checked against // the BMT hash of the chunk func (p Prover) Verify(i int, proof Proof) (root []byte, err error) { + var section []byte + if i%2 == 0 { + section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...) + } else { + section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...) + } i = i / 2 n := p.bmt.leaves[i] + hasher := p.hasher() isLeft := n.isLeft - root, err = doHash(n.hasher, proof.ProveSegment) + root, err = doHash(hasher, section) if err != nil { return nil, err } n = n.parent - for _, sister := range proof.ProofSegments { + for _, sister := range proof.ProofSegments[1:] { if isLeft { - root, err = doHash(n.hasher, root, sister) + root, err = doHash(hasher, root, sister) } else { - root, err = doHash(n.hasher, sister, root) + root, err = doHash(hasher, sister, root) } if err != nil { return nil, err @@ -63,7 +90,7 @@ func (p Prover) Verify(i int, proof Proof) (root []byte, err error) { isLeft = n.isLeft n = n.parent } - return sha3hash(proof.Span, root) + return doHash(hasher, proof.Span, root) } func (n *node) getSister(isLeft bool) []byte { diff --git a/pkg/bmt/proof_test.go b/pkg/bmt/proof_test.go index 337b1bf3420..1b7f6d3b3dd 100644 --- a/pkg/bmt/proof_test.go +++ b/pkg/bmt/proof_test.go @@ -20,7 +20,8 @@ func TestProofCorrectness(t *testing.T) { t.Parallel() testData := []byte("hello world") - testData = append(testData, make([]byte, 4096-len(testData))...) + testDataPadded := make([]byte, swarm.ChunkSize) + copy(testDataPadded, testData) verifySegments := func(t *testing.T, exp []string, found [][]byte) { t.Helper() @@ -57,8 +58,8 @@ func TestProofCorrectness(t *testing.T) { if err != nil { t.Fatal(err) } - - rh, err := hh.Hash(nil) + pr := bmt.Prover{hh} + rh, err := pr.Hash(nil) if err != nil { t.Fatal(err) } @@ -66,9 +67,10 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for left most", func(t *testing.T) { t.Parallel() - proof := bmt.Prover{hh}.Proof(0) + proof := pr.Proof(0) expSegmentStrings := []string{ + "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -79,7 +81,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testData[:2*hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) { t.Fatal("section incorrect") } @@ -91,9 +93,10 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for right most", func(t *testing.T) { t.Parallel() - proof := bmt.Prover{hh}.Proof(127) + proof := pr.Proof(127) expSegmentStrings := []string{ + "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -104,7 +107,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testData[126*hh.Size():]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) { t.Fatal("section incorrect") } @@ -116,9 +119,10 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for middle", func(t *testing.T) { t.Parallel() - proof := bmt.Prover{hh}.Proof(64) + proof := pr.Proof(64) expSegmentStrings := []string{ + "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -129,7 +133,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testData[64*hh.Size():66*hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) { t.Fatal("section incorrect") } @@ -142,6 +146,7 @@ func TestProofCorrectness(t *testing.T) { t.Parallel() segmentStrings := []string{ + "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -159,9 +164,9 @@ func TestProofCorrectness(t *testing.T) { segments = append(segments, decoded) } - segment := testData[64*hh.Size() : 66*hh.Size()] + segment := testDataPadded[64*hh.Size() : 65*hh.Size()] - rootHash, err := bmt.Prover{hh}.Verify(64, bmt.Proof{ + rootHash, err := pr.Verify(64, bmt.Proof{ ProveSegment: segment, ProofSegments: segments, Span: bmt.LengthToSpan(4096), @@ -200,6 +205,7 @@ func TestProof(t *testing.T) { } rh, err := hh.Hash(nil) + pr := bmt.Prover{hh} if err != nil { t.Fatal(err) } @@ -209,7 +215,7 @@ func TestProof(t *testing.T) { t.Run(fmt.Sprintf("segmentIndex %d", i), func(t *testing.T) { t.Parallel() - proof := bmt.Prover{hh}.Proof(i) + proof := pr.Proof(i) h := pool.Get() defer pool.Put(h) diff --git a/pkg/bmt/trhasher.go b/pkg/bmt/trhasher.go deleted file mode 100644 index 00df6664b85..00000000000 --- a/pkg/bmt/trhasher.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2023 The Swarm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package bmt - -import ( - "hash" - - "github.com/ethersphere/bee/pkg/swarm" -) - -func NewTrHasher(prefix []byte) *Hasher { - capacity := 32 - hasherFact := func() hash.Hash { return swarm.NewTrHasher(prefix) } - conf := NewConf(hasherFact, swarm.BmtBranches, capacity) - - return &Hasher{ - Conf: conf, - result: make(chan []byte), - errc: make(chan error, 1), - span: make([]byte, SpanSize), - bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher), - } -} diff --git a/pkg/storer/sample.go b/pkg/storer/sample.go index 690e06affc3..07c92885ac0 100644 --- a/pkg/storer/sample.go +++ b/pkg/storer/sample.go @@ -10,6 +10,7 @@ import ( "crypto/hmac" "encoding/binary" "fmt" + "hash" "math/big" "sort" "sync" @@ -43,7 +44,13 @@ type Sample struct { func RandSample(t *testing.T, anchor []byte) Sample { t.Helper() - hasher := bmt.NewTrHasher(anchor) + prefixHasherFactory := func() hash.Hash { + return swarm.NewPrefixHasher(anchor) + } + pool := bmt.NewPool(bmt.NewConf(prefixHasherFactory, swarm.BmtBranches, 8)) + + hasher := pool.Get() + defer pool.Put(hasher) items := make([]SampleItem, SampleSize) for i := 0; i < SampleSize; i++ { diff --git a/pkg/swarm/hasher.go b/pkg/swarm/hasher.go index 485b61ab398..b9823bb50a1 100644 --- a/pkg/swarm/hasher.go +++ b/pkg/swarm/hasher.go @@ -15,15 +15,15 @@ func NewHasher() hash.Hash { return sha3.NewLegacyKeccak256() } -type trHasher struct { +type PrefixHasher struct { hash.Hash prefix []byte } -// NewTrHasher returns new hasher which is Keccak-256 hasher +// NewPrefixHasher returns new hasher which is Keccak-256 hasher // with prefix value added as initial data. -func NewTrHasher(prefix []byte) hash.Hash { - h := &trHasher{ +func NewPrefixHasher(prefix []byte) hash.Hash { + h := &PrefixHasher{ Hash: NewHasher(), prefix: prefix, } @@ -32,7 +32,7 @@ func NewTrHasher(prefix []byte) hash.Hash { return h } -func (h *trHasher) Reset() { +func (h *PrefixHasher) Reset() { h.Hash.Reset() _, _ = h.Write(h.prefix) } diff --git a/pkg/swarm/hasher_test.go b/pkg/swarm/hasher_test.go index bfee0a78e98..3811e09605a 100644 --- a/pkg/swarm/hasher_test.go +++ b/pkg/swarm/hasher_test.go @@ -66,7 +66,7 @@ func TestNewTrHasher(t *testing.T) { // Run tests cases against TrHasher for _, tc := range tests { - h := swarm.NewTrHasher(tc.prefix) + h := swarm.NewPrefixHasher(tc.prefix) _, err := h.Write(tc.plaintext) if err != nil {