Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ph4): bmt #4308

Merged
merged 8 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/bmt/bmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ type Hasher struct {
span []byte // The span of the data subsumed under the chunk
}

// facade
func NewHasher(hasherFact func() hash.Hash) *Hasher {
conf := NewConf(hasherFact, swarm.BmtBranches, 32)

return &Hasher{
Conf: conf,
result: make(chan []byte),
errc: make(chan error, 1),
span: make([]byte, SpanSize),
bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher),
}
}

// Capacity returns the maximum amount of bytes that will be processed by this hasher implementation.
// since BMT assumes a balanced binary tree, capacity it is always a power of 2
func (h *Hasher) Capacity() int {
Expand Down
39 changes: 33 additions & 6 deletions pkg/bmt/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ type Proof struct {
Index int
}

// Override base hash function of Hasher to fill buffer with zeros until chunk length
func (p Prover) Hash(b []byte) ([]byte, error) {
for i := p.size; i < p.maxSize; i += len(zerosection) {
_, err := p.Write(zerosection)
if err != nil {
return []byte{}, err
}
}
return p.Hasher.Hash(b)
}

// Proof returns the inclusion proof of the i-th data segment
func (p Prover) Proof(i int) Proof {
index := i
Expand All @@ -36,34 +47,50 @@ func (p Prover) Proof(i int) Proof {
secsize := 2 * p.segmentSize
offset := i * secsize
section := p.bmt.buffer[offset : offset+secsize]
return Proof{section, sisters, p.span, index}
left := section[:p.segmentSize]
right := section[p.segmentSize:]
var segment, firstSegmentSister []byte
if index%2 == 0 {
segment, firstSegmentSister = left, right
} else {
segment, firstSegmentSister = right, left
}
sisters = append([][]byte{firstSegmentSister}, sisters...)
return Proof{segment, sisters, p.span, index}
}

// Verify returns the bmt hash obtained from the proof which can then be checked against
// the BMT hash of the chunk
func (p Prover) Verify(i int, proof Proof) (root []byte, err error) {
var section []byte
if i%2 == 0 {
section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...)
} else {
section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...)
}
i = i / 2
n := p.bmt.leaves[i]
hasher := p.hasher()
isLeft := n.isLeft
root, err = doHash(n.hasher, proof.ProveSegment)
root, err = doHash(hasher, section)
if err != nil {
return nil, err
}
n = n.parent

for _, sister := range proof.ProofSegments {
for _, sister := range proof.ProofSegments[1:] {
if isLeft {
root, err = doHash(n.hasher, root, sister)
root, err = doHash(hasher, root, sister)
} else {
root, err = doHash(n.hasher, sister, root)
root, err = doHash(hasher, sister, root)
}
if err != nil {
return nil, err
}
isLeft = n.isLeft
n = n.parent
}
return sha3hash(proof.Span, root)
return doHash(hasher, proof.Span, root)
}

func (n *node) getSister(isLeft bool) []byte {
Expand Down
30 changes: 18 additions & 12 deletions pkg/bmt/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ func TestProofCorrectness(t *testing.T) {
t.Parallel()

testData := []byte("hello world")
testData = append(testData, make([]byte, 4096-len(testData))...)
testDataPadded := make([]byte, swarm.ChunkSize)
copy(testDataPadded, testData)

verifySegments := func(t *testing.T, exp []string, found [][]byte) {
t.Helper()
Expand Down Expand Up @@ -57,18 +58,19 @@ func TestProofCorrectness(t *testing.T) {
if err != nil {
t.Fatal(err)
}

rh, err := hh.Hash(nil)
pr := bmt.Prover{hh}
rh, err := pr.Hash(nil)
if err != nil {
t.Fatal(err)
}

t.Run("proof for left most", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(0)
proof := pr.Proof(0)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -79,7 +81,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[:2*hh.Size()]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) {
t.Fatal("section incorrect")
}

Expand All @@ -91,9 +93,10 @@ func TestProofCorrectness(t *testing.T) {
t.Run("proof for right most", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(127)
proof := pr.Proof(127)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -104,7 +107,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[126*hh.Size():]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) {
t.Fatal("section incorrect")
}

Expand All @@ -116,9 +119,10 @@ func TestProofCorrectness(t *testing.T) {
t.Run("proof for middle", func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(64)
proof := pr.Proof(64)

expSegmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -129,7 +133,7 @@ func TestProofCorrectness(t *testing.T) {

verifySegments(t, expSegmentStrings, proof.ProofSegments)

if !bytes.Equal(proof.ProveSegment, testData[64*hh.Size():66*hh.Size()]) {
if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) {
t.Fatal("section incorrect")
}

Expand All @@ -142,6 +146,7 @@ func TestProofCorrectness(t *testing.T) {
t.Parallel()

segmentStrings := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
Expand All @@ -159,9 +164,9 @@ func TestProofCorrectness(t *testing.T) {
segments = append(segments, decoded)
}

segment := testData[64*hh.Size() : 66*hh.Size()]
segment := testDataPadded[64*hh.Size() : 65*hh.Size()]

rootHash, err := bmt.Prover{hh}.Verify(64, bmt.Proof{
rootHash, err := pr.Verify(64, bmt.Proof{
ProveSegment: segment,
ProofSegments: segments,
Span: bmt.LengthToSpan(4096),
Expand Down Expand Up @@ -200,6 +205,7 @@ func TestProof(t *testing.T) {
}

rh, err := hh.Hash(nil)
pr := bmt.Prover{hh}
if err != nil {
t.Fatal(err)
}
Expand All @@ -209,7 +215,7 @@ func TestProof(t *testing.T) {
t.Run(fmt.Sprintf("segmentIndex %d", i), func(t *testing.T) {
t.Parallel()

proof := bmt.Prover{hh}.Proof(i)
proof := pr.Proof(i)

h := pool.Get()
defer pool.Put(h)
Expand Down
25 changes: 0 additions & 25 deletions pkg/bmt/trhasher.go

This file was deleted.

9 changes: 8 additions & 1 deletion pkg/storer/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/hmac"
"encoding/binary"
"fmt"
"hash"
"math/big"
"sort"
"sync"
Expand Down Expand Up @@ -43,7 +44,13 @@ type Sample struct {
func RandSample(t *testing.T, anchor []byte) Sample {
t.Helper()

hasher := bmt.NewTrHasher(anchor)
prefixHasherFactory := func() hash.Hash {
return swarm.NewPrefixHasher(anchor)
}
pool := bmt.NewPool(bmt.NewConf(prefixHasherFactory, swarm.BmtBranches, 8))

hasher := pool.Get()
defer pool.Put(hasher)

items := make([]SampleItem, SampleSize)
for i := 0; i < SampleSize; i++ {
Expand Down
10 changes: 5 additions & 5 deletions pkg/swarm/hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ func NewHasher() hash.Hash {
return sha3.NewLegacyKeccak256()
}

type trHasher struct {
type PrefixHasher struct {
hash.Hash
prefix []byte
}

// NewTrHasher returns new hasher which is Keccak-256 hasher
// NewPrefixHasher returns new hasher which is Keccak-256 hasher
// with prefix value added as initial data.
func NewTrHasher(prefix []byte) hash.Hash {
h := &trHasher{
func NewPrefixHasher(prefix []byte) hash.Hash {
h := &PrefixHasher{
Hash: NewHasher(),
prefix: prefix,
}
Expand All @@ -32,7 +32,7 @@ func NewTrHasher(prefix []byte) hash.Hash {
return h
}

func (h *trHasher) Reset() {
func (h *PrefixHasher) Reset() {
h.Hash.Reset()
_, _ = h.Write(h.prefix)
}
2 changes: 1 addition & 1 deletion pkg/swarm/hasher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestNewTrHasher(t *testing.T) {

// Run tests cases against TrHasher
for _, tc := range tests {
h := swarm.NewTrHasher(tc.prefix)
h := swarm.NewPrefixHasher(tc.prefix)

_, err := h.Write(tc.plaintext)
if err != nil {
Expand Down
Loading