diff --git a/bmt/bmt.go b/bmt/bmt.go index 18eab5a2bc..d20e2c1aa3 100644 --- a/bmt/bmt.go +++ b/bmt/bmt.go @@ -18,11 +18,17 @@ package bmt import ( + "context" + "encoding/binary" + "errors" "fmt" "hash" "strings" "sync" "sync/atomic" + + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" ) /* @@ -60,6 +66,10 @@ const ( PoolSize = 8 ) +var ( + zeroSpan = make([]byte, 8) +) + // BaseHasherFunc is a hash.Hash constructor function used for the base hash of the BMT. // implemented by Keccak256 SHA3 sha3.NewLegacyKeccak256 type BaseHasherFunc func() hash.Hash @@ -75,8 +85,10 @@ type BaseHasherFunc func() hash.Hash // the tree and itself in a state reusable for hashing a new chunk // - generates and verifies segment inclusion proofs (TODO:) type Hasher struct { - pool *TreePool // BMT resource pool - bmt *tree // prebuilt BMT resource for flowcontrol and proofs + pool *TreePool // BMT resource pool + bmt *tree // prebuilt BMT resource for flowcontrol and proofs + size int // bytes written to Hasher since last Reset() + cursor int // cursor to write to on next Write() call } // New creates a reusable BMT Hasher that @@ -276,14 +288,56 @@ func newTree(segmentSize, depth int, hashfunc func() hash.Hash) *tree { } } -// methods needed to implement hash.Hash +// Implements param.SectionWriter +func (h *Hasher) SetWriter(_ param.SectionWriterFunc) param.SectionWriter { + log.Warn("Synchasher does not currently support SectionWriter chaining") + return h +} + +// Implements param.SectionWriter +func (h *Hasher) SectionSize() int { + return h.pool.SegmentSize +} + +// Implements param.SectionWriter +func (h *Hasher) SetLength(length int) { +} + +// Implements param.SectionWriter +func (h *Hasher) SetSpan(length int) { + span := LengthToSpan(length) + h.getTree().span = span +} + +// Implements storage.SwarmHash +func (h *Hasher) SetSpanBytes(b []byte) { + t := h.getTree() + t.span = make([]byte, 8) + copy(t.span, b) +} + +// Implements param.SectionWriter +func (h *Hasher) Branches() int { + return h.pool.SegmentCount +} + +// Implements param.SectionWriter +func (h *Hasher) Init(_ context.Context, _ func(error)) { +} -// Size returns the size +// Size returns the digest size +// Implements hash.Hash in param.SectionWriter func (h *Hasher) Size() int { return h.pool.SegmentSize } +// Seek sets the section that will be written to on the next Write() +func (h *Hasher) SeekSection(offset int) { + h.cursor = offset +} + // BlockSize returns the block size +// Implements hash.Hash in param.SectionWriter func (h *Hasher) BlockSize() int { return 2 * h.pool.SegmentSize } @@ -293,31 +347,35 @@ func (h *Hasher) BlockSize() int { // hash.Hash interface Sum method appends the byte slice to the underlying // data before it calculates and returns the hash of the chunk // caller must make sure Sum is not called concurrently with Write, writeSection +// Implements hash.Hash in param.SectionWriter func (h *Hasher) Sum(b []byte) (s []byte) { t := h.getTree() + if h.size == 0 && t.offset == 0 { + h.releaseTree() + return h.pool.zerohashes[h.pool.Depth] + } // write the last section with final flag set to true go h.writeSection(t.cursor, t.section, true, true) // wait for the result s = <-t.result + if t.span == nil { + t.span = LengthToSpan(h.size) + } span := t.span // release the tree resource back to the pool h.releaseTree() - // b + sha3(span + BMT(pure_chunk)) - if len(span) == 0 { - return append(b, s...) 
- } return doSum(h.pool.hasher(), b, span, s) } -// methods needed to implement the SwarmHash and the io.Writer interfaces - // Write calls sequentially add to the buffer to be hashed, // with every full segment calls writeSection in a go routine +// Implements hash.Hash (io.Writer) in param.SectionWriter func (h *Hasher) Write(b []byte) (int, error) { l := len(b) if l == 0 || l > h.pool.Size { return 0, nil } + h.size += len(b) t := h.getTree() secsize := 2 * h.pool.SegmentSize // calculate length of missing bit to complete current open section @@ -359,20 +417,13 @@ func (h *Hasher) Write(b []byte) (int, error) { } // Reset needs to be called before writing to the hasher +// Implements hash.Hash in param.SectionWriter func (h *Hasher) Reset() { + h.cursor = 0 + h.size = 0 h.releaseTree() } -// methods needed to implement the SwarmHash interface - -// ResetWithLength needs to be called before writing to the hasher -// the argument is supposed to be the byte slice binary representation of -// the length of the data subsumed under the hash, i.e., span -func (h *Hasher) ResetWithLength(span []byte) { - h.Reset() - h.getTree().span = span -} - // releaseTree gives back the Tree to the pool whereby it unlocks // it resets tree, segment and index func (h *Hasher) releaseTree() { @@ -395,30 +446,30 @@ func (h *Hasher) releaseTree() { } // NewAsyncWriter extends Hasher with an interface for concurrent segment/section writes +// TODO: Instead of explicitly setting double size of segment should be dynamic and chunked internally. If not, we have to keep different bmt hashers generation functions for different purposes in the same instance, or cope with added complexity of bmt hasher generation functions having to receive parameters func (h *Hasher) NewAsyncWriter(double bool) *AsyncHasher { secsize := h.pool.SegmentSize if double { secsize *= 2 } + seccount := h.pool.SegmentCount + if double { + seccount /= 2 + } write := func(i int, section []byte, final bool) { h.writeSection(i, section, double, final) } return &AsyncHasher{ - Hasher: h, - double: double, - secsize: secsize, - write: write, + Hasher: h, + double: double, + secsize: secsize, + seccount: seccount, + write: write, + jobSize: 0, + sought: true, } } -// SectionWriter is an asynchronous segment/section writer interface -type SectionWriter interface { - Reset() // standard init to be called before reuse - Write(index int, data []byte) // write into section of index - Sum(b []byte, length int, span []byte) []byte // returns the hash of the buffer - SectionSize() int // size of the async section unit to use -} - // AsyncHasher extends BMT Hasher with an asynchronous segment/section writer interface // AsyncHasher is unsafe and does not check indexes and section data lengths // it must be used with the right indexes and length and the right number of sections @@ -434,33 +485,94 @@ type SectionWriter interface { // * it will not leak processes if not all sections are written but it blocks // and keeps the resource which can be released calling Reset() type AsyncHasher struct { - *Hasher // extends the Hasher - mtx sync.Mutex // to lock the cursor access - double bool // whether to use double segments (call Hasher.writeSection) - secsize int // size of base section (size of hash or double) - write func(i int, section []byte, final bool) + *Hasher // extends the Hasher + mtx sync.Mutex // to lock the cursor access + double bool // whether to use double segments (call Hasher.writeSection) + secsize int // size of base section (size of hash or 
double)
+	seccount int        // base section count
+	write    func(i int, section []byte, final bool)
+	errFunc  func(error)
+	all      bool // if all written in one go, temporary workaround
+	sought   bool
+	jobSize  int
 }
 
-// methods needed to implement AsyncWriter
+// Implements param.SectionWriter
+// TODO context should be implemented all across (ie original TODO in TreePool.reserve())
+func (sw *AsyncHasher) Init(_ context.Context, errFunc func(error)) {
+	sw.errFunc = errFunc
+}
+
+// Implements param.SectionWriter
+func (sw *AsyncHasher) Reset() {
+	sw.sought = true
+	sw.jobSize = 0
+	sw.all = false
+	sw.Hasher.Reset()
+}
+
+func (sw *AsyncHasher) SetLength(length int) {
+	sw.jobSize = length
+}
+
+// Implements param.SectionWriter
+func (sw *AsyncHasher) SetWriter(_ param.SectionWriterFunc) param.SectionWriter {
+	sw.errFunc(errors.New("Asynchasher does not currently support SectionWriter chaining"))
+	return sw
+}
 
 // SectionSize returns the size of async section unit to use
+// Implements param.SectionWriter
 func (sw *AsyncHasher) SectionSize() int {
 	return sw.secsize
 }
 
+// Branches returns the branching factor, i.e. the number of sections that make up the full BMT input
+// Implements param.SectionWriter
+func (sw *AsyncHasher) Branches() int {
+	return sw.seccount
+}
+
+// SeekSection sets the cursor where the next Write() will write
+// It locks the cursor until Write() is called; if no Write() is called, it will hang.
+// Implements param.SectionWriter
+func (sw *AsyncHasher) SeekSection(offset int) {
+	sw.mtx.Lock()
+	sw.Hasher.SeekSection(offset)
+}
+
+// Write writes to the current position cursor of the Hasher
+// The cursor must first be manually set with SeekSection()
+// The method will NOT advance the cursor.
+// Implements hash.Hash in param.SectionWriter
+func (sw *AsyncHasher) Write(section []byte) (int, error) {
+	defer sw.mtx.Unlock()
+	sw.Hasher.size += len(section)
+	return sw.writeSection(sw.Hasher.cursor, section)
+}
+
 // Write writes the i-th section of the BMT base
 // this function can and is meant to be called concurrently
 // it sets max segment threadsafely
-func (sw *AsyncHasher) Write(i int, section []byte) {
-	sw.mtx.Lock()
-	defer sw.mtx.Unlock()
+func (sw *AsyncHasher) writeSection(i int, section []byte) (int, error) {
+	// TODO: Temporary workaround for chunkwise write
+	if i < 0 {
+		sw.Hasher.cursor = 0
+		sw.Hasher.Reset()
+		sw.Hasher.SetLength(len(section))
+		sw.Hasher.Write(section)
+		sw.all = true
+		return len(section), nil
+	}
+	//sw.mtx.Lock() // this lock is now set in SeekSection
+	// defer sw.mtx.Unlock() // this unlock is still left in Write()
 	t := sw.getTree()
 	// cursor keeps track of the rightmost section written so far
 	// if index is lower than cursor then just write non-final section as is
 	if i < t.cursor {
 		// if index is not the rightmost, safe to write section
 		go sw.write(i, section, false)
-		return
+		return len(section), nil
 	}
 	// if there is a previous rightmost section safe to write section
 	if t.offset > 0 {
@@ -470,7 +582,7 @@ func (sw *AsyncHasher) Write(i int, section []byte) {
 		t.section = make([]byte, sw.secsize)
 		copy(t.section, section)
 		go sw.write(i, t.section, true)
-		return
+		return len(section), nil
 	}
 	// the rightmost section just changed, so we write the previous one as non-final
 	go sw.write(t.cursor, t.section, false)
@@ -481,6 +593,7 @@ func (sw *AsyncHasher) Write(i int, section []byte) {
 	t.offset = i*sw.secsize + 1
 	t.section = make([]byte, sw.secsize)
 	copy(t.section, section)
+	return len(section), nil
 }
 
 // Sum can be called any time once the 
length and the span is known @@ -492,12 +605,20 @@ func (sw *AsyncHasher) Write(i int, section []byte) { // length: known length of the input (unsafe; undefined if out of range) // meta: metadata to hash together with BMT root for the final digest // e.g., span for protection against existential forgery -func (sw *AsyncHasher) Sum(b []byte, length int, meta []byte) (s []byte) { +// +// Implements hash.hash in param.SectionWriter +func (sw *AsyncHasher) Sum(b []byte) (s []byte) { + if sw.all { + return sw.Hasher.Sum(nil) + } sw.mtx.Lock() t := sw.getTree() + length := sw.jobSize if length == 0 { + sw.releaseTree() sw.mtx.Unlock() s = sw.pool.zerohashes[sw.pool.Depth] + return } else { // for non-zero input the rightmost section is written to the tree asynchronously // if the actual last section has been written (t.cursor == length/t.secsize) @@ -515,15 +636,13 @@ func (sw *AsyncHasher) Sum(b []byte, length int, meta []byte) (s []byte) { } // relesase the tree back to the pool sw.releaseTree() - // if no meta is given just append digest to b - if len(meta) == 0 { - return append(b, s...) - } + meta := t.span // hash together meta and BMT root hash using the pools return doSum(sw.pool.hasher(), b, meta, s) } // writeSection writes the hash of i-th section into level 1 node of the BMT tree +// TODO: h.size increases even on multiple writes to the same section of a section func (h *Hasher) writeSection(i int, section []byte, double bool, final bool) { // select the leaf node for the section var n *node @@ -688,3 +807,11 @@ func calculateDepthFor(n int) (d int) { } return d + 1 } + +// creates a binary span size representation +// to pass to bmt.SectionWriter +func LengthToSpan(length int) []byte { + spanBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(spanBytes, uint64(length)) + return spanBytes +} diff --git a/bmt/bmt_test.go b/bmt/bmt_test.go index fc020eb7c2..1cfd611a22 100644 --- a/bmt/bmt_test.go +++ b/bmt/bmt_test.go @@ -26,10 +26,15 @@ import ( "testing" "time" + "github.com/ethersphere/swarm/param" "github.com/ethersphere/swarm/testutil" "golang.org/x/crypto/sha3" ) +func init() { + testutil.Init() +} + // the actual data length generated (could be longer than max datalength of the BMT) const BufferSize = 4128 @@ -141,10 +146,10 @@ func TestHasherEmptyData(t *testing.T) { defer pool.Drain(0) bmt := New(pool) rbmt := NewRefHasher(hasher, count) - refHash := rbmt.Hash(data) - expHash := syncHash(bmt, nil, data) - if !bytes.Equal(expHash, refHash) { - t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) + expHash := rbmt.Hash(data) + resHash := syncHash(bmt, 0, data) + if !bytes.Equal(expHash, resHash) { + t.Fatalf("hash mismatch with reference. 
expected %x, got %x", resHash, expHash) } }) } @@ -197,15 +202,19 @@ func TestAsyncCorrectness(t *testing.T) { bmt := New(pool) d := data[:n] rbmt := NewRefHasher(hasher, count) - exp := rbmt.Hash(d) - got := syncHash(bmt, nil, d) + expNoMeta := rbmt.Hash(d) + h := hasher() + h.Write(zeroSpan) + h.Write(expNoMeta) + exp := h.Sum(nil) + got := syncHash(bmt, 0, d) if !bytes.Equal(got, exp) { - t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got) + t.Fatalf("wrong sync hash (syncpart) for datalength %v: expected %x (ref), got %x", n, exp, got) } sw := bmt.NewAsyncWriter(double) - got = asyncHashRandom(sw, nil, d, wh) + got = asyncHashRandom(sw, 0, d, wh) if !bytes.Equal(got, exp) { - t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got) + t.Fatalf("wrong async hash (asyncpart) for datalength %v: expected %x, got %x", n, exp, got) } } }) @@ -288,8 +297,12 @@ func TestBMTWriterBuffers(t *testing.T) { bmt := New(pool) data := testutil.RandomBytes(1, n) rbmt := NewRefHasher(hasher, count) - refHash := rbmt.Hash(data) - expHash := syncHash(bmt, nil, data) + refNoMetaHash := rbmt.Hash(data) + h := hasher() + h.Write(zeroSpan) + h.Write(refNoMetaHash) + refHash := h.Sum(nil) + expHash := syncHash(bmt, 0, data) if !bytes.Equal(expHash, refHash) { t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) } @@ -308,6 +321,7 @@ func TestBMTWriterBuffers(t *testing.T) { return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) } } + bmt.SetSpan(0) hash := bmt.Sum(nil) if !bytes.Equal(hash, expHash) { return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) @@ -346,11 +360,16 @@ func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, coun if len(d) < n { n = len(d) } - binary.BigEndian.PutUint64(span, uint64(n)) + binary.LittleEndian.PutUint64(span, uint64(n)) data := d[:n] rbmt := NewRefHasher(hasher, count) - exp := sha3hash(span, rbmt.Hash(data)) - got := syncHash(bmt, span, data) + var exp []byte + if n == 0 { + exp = bmt.pool.zerohashes[bmt.pool.Depth] + } else { + exp = sha3hash(span, rbmt.Hash(data)) + } + got := syncHash(bmt, n, data) if !bytes.Equal(got, exp) { return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) } @@ -460,7 +479,7 @@ func benchmarkBMT(t *testing.B, n int) { t.ReportAllocs() t.ResetTimer() for i := 0; i < t.N; i++ { - syncHash(bmt, nil, data) + syncHash(bmt, 0, data) } } @@ -478,7 +497,7 @@ func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) { t.ReportAllocs() t.ResetTimer() for i := 0; i < t.N; i++ { - asyncHash(bmt, nil, n, wh, idxs, segments) + asyncHash(bmt, 0, n, wh, idxs, segments) } } @@ -498,7 +517,7 @@ func benchmarkPool(t *testing.B, poolsize, n int) { go func() { defer wg.Done() bmt := New(pool) - syncHash(bmt, nil, data) + syncHash(bmt, 0, data) }() } wg.Wait() @@ -519,8 +538,9 @@ func benchmarkRefHasher(t *testing.B, n int) { } // Hash hashes the data and the span using the bmt hasher -func syncHash(h *Hasher, span, data []byte) []byte { - h.ResetWithLength(span) +func syncHash(h *Hasher, spanLength int, data []byte) []byte { + h.Reset() + h.SetSpan(spanLength) h.Write(data) return h.Sum(nil) } @@ -547,23 +567,27 @@ func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) { } // splits the input data performs a random shuffle to mock async section writes -func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) { +func asyncHashRandom(bmt 
param.SectionWriter, spanLength int, data []byte, wh whenHash) (s []byte) { idxs, segments := splitAndShuffle(bmt.SectionSize(), data) - return asyncHash(bmt, span, len(data), wh, idxs, segments) + return asyncHash(bmt, spanLength, len(data), wh, idxs, segments) } -// mock for async section writes for BMT SectionWriter +// mock for async section writes for param.SectionWriter // requires a permutation (a random shuffle) of list of all indexes of segments // and writes them in order to the appropriate section // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes]) -func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { +func asyncHash(bmt param.SectionWriter, spanLength int, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { bmt.Reset() if l == 0 { - return bmt.Sum(nil, l, span) + bmt.SetLength(l) + bmt.SetSpan(spanLength) + return bmt.Sum(nil) } c := make(chan []byte, 1) hashf := func() { - c <- bmt.Sum(nil, l, span) + bmt.SetLength(l) + bmt.SetSpan(spanLength) + c <- bmt.Sum(nil) } maxsize := len(idxs) var r int @@ -571,13 +595,35 @@ func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, s r = rand.Intn(maxsize) } for i, idx := range idxs { - bmt.Write(idx, segments[idx]) + bmt.SeekSection(idx) + bmt.Write(segments[idx]) if (wh == first || wh == random) && i == r { go hashf() } } if wh == last { - return bmt.Sum(nil, l, span) + bmt.SetLength(l) + bmt.SetSpan(spanLength) + return bmt.Sum(nil) } return <-c } + +// TestUseSyncAsOrdinaryHasher verifies that the bmt.Hasher can be used with the hash.Hash interface +func TestUseSyncAsOrdinaryHasher(t *testing.T) { + hasher := sha3.NewLegacyKeccak256 + pool := NewTreePool(hasher, segmentCount, PoolSize) + bmt := New(pool) + bmt.Write([]byte("foo")) + res := bmt.Sum(nil) + refh := NewRefHasher(hasher, 128) + resh := refh.Hash([]byte("foo")) + hsub := hasher() + span := LengthToSpan(3) + hsub.Write(span) + hsub.Write(resh) + refRes := hsub.Sum(nil) + if !bytes.Equal(res, refRes) { + t.Fatalf("normalhash; expected %x, got %x", refRes, res) + } +} diff --git a/file/encrypt/encrypt.go b/file/encrypt/encrypt.go new file mode 100644 index 0000000000..4ee570afd1 --- /dev/null +++ b/file/encrypt/encrypt.go @@ -0,0 +1,117 @@ +package encrypt + +import ( + "context" + crand "crypto/rand" + "fmt" + "hash" + + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/storage/encryption" + "golang.org/x/crypto/sha3" +) + +type Encrypt struct { + key []byte + e encryption.Encryption + w param.SectionWriter + length int + span int + keyHash hash.Hash + errFunc func(error) +} + +func New(key []byte, initCtr uint32, hashFunc param.SectionWriterFunc) (*Encrypt, error) { + if key == nil { + key = make([]byte, encryption.KeyLength) + c, err := crand.Read(key) + if err != nil { + return nil, err + } + if c < encryption.KeyLength { + return nil, fmt.Errorf("short read: %d", c) + } + } else if len(key) != encryption.KeyLength { + return nil, fmt.Errorf("encryption key must be %d bytes", encryption.KeyLength) + } + e := &Encrypt{ + e: encryption.New(key, 0, initCtr, sha3.NewLegacyKeccak256), + key: make([]byte, encryption.KeyLength), + keyHash: param.HashFunc(), + } + copy(e.key, key) + return e, nil +} + +func (e *Encrypt) SetWriter(hashFunc param.SectionWriterFunc) param.SectionWriter { + e.w = hashFunc(nil) + return e + +} + +func (e *Encrypt) Init(_ context.Context, errFunc 
func(error)) { + e.errFunc = errFunc +} + +func (e *Encrypt) SeekSection(offset int) { + e.w.SeekSection(offset) +} + +func (e *Encrypt) Write(b []byte) (int, error) { + cipherText, err := e.e.Encrypt(b) + if err != nil { + e.errFunc(err) + return 0, err + } + return e.w.Write(cipherText) +} + +func (e *Encrypt) Reset() { + e.w.Reset() +} + +func (e *Encrypt) SetLength(length int) { + e.length = length + e.w.SetLength(length) +} + +func (e *Encrypt) SetSpan(length int) { + e.span = length + e.w.SetSpan(length) +} + +func (e *Encrypt) Sum(b []byte) []byte { + // derive new key + oldKey := make([]byte, encryption.KeyLength) + copy(oldKey, e.key) + e.keyHash.Reset() + e.keyHash.Write(e.key) + newKey := e.keyHash.Sum(nil) + copy(e.key, newKey) + s := e.w.Sum(b) + log.Trace("key", "key", oldKey, "ekey", e.key, "newkey", newKey) + return append(oldKey, s...) +} + +// DigestSize implements param.SectionWriter +func (e *Encrypt) BlockSize() int { + return e.Size() +} + +// DigestSize implements param.SectionWriter +// TODO: cache these calculations +func (e *Encrypt) Size() int { + return e.w.Size() + encryption.KeyLength +} + +// SectionSize implements param.SectionWriter +func (e *Encrypt) SectionSize() int { + return e.w.SectionSize() +} + +// Branches implements param.SectionWriter +// TODO: cache these calculations +func (e *Encrypt) Branches() int { + return e.w.Branches() / (e.Size() / e.w.SectionSize()) +} diff --git a/file/encrypt/encrypt_test.go b/file/encrypt/encrypt_test.go new file mode 100644 index 0000000000..e0ab7673c7 --- /dev/null +++ b/file/encrypt/encrypt_test.go @@ -0,0 +1,225 @@ +package encrypt + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/file/hasher" + "github.com/ethersphere/swarm/file/testutillocal" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/storage/encryption" + "github.com/ethersphere/swarm/testutil" + "golang.org/x/crypto/sha3" +) + +const ( + sectionSize = 32 + branches = 128 + chunkSize = 4096 +) + +var ( + testKey = append(make([]byte, encryption.KeyLength-1), byte(0x2a)) +) + +func init() { + testutil.Init() +} + +func TestKey(t *testing.T) { + + hashFunc := testutillocal.NewBMTHasherFunc(0) + + e, err := New(nil, 42, hashFunc) + if err != nil { + t.Fatal(err) + } + if e.key == nil { + t.Fatalf("new key nil; expected not nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + errFunc := func(error) {} + cache := testutillocal.NewCache() + cache.Init(ctx, errFunc) + cacheFunc := func(_ context.Context) param.SectionWriter { + return cache + } + e, err = New(testKey, 42, cacheFunc) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testKey, e.key) { + t.Fatalf("key seed; expected %x, got %x", testKey, e.key) + } + e.SetWriter(cacheFunc) + + _, data := testutil.SerialData(chunkSize, 255, 0) + e.Write(data) // 0 + e.SetLength(chunkSize) + doubleRef := e.Sum(nil) + refKey := doubleRef[:encryption.KeyLength] + if !bytes.Equal(refKey, testKey) { + t.Fatalf("returned ref key, expected %x, got %x", testKey, refKey) + } + + correctNextKeyHex := "0xbeced09521047d05b8960b7e7bcc1d1292cf3e4b2a6b63f48335cbde5f7545d2" + nextKeyHex := hexutil.Encode(e.key) + if nextKeyHex != correctNextKeyHex { + t.Fatalf("key next; expected %s, got %s", correctNextKeyHex, nextKeyHex) + } +} + +func TestEncryptOneChunk(t *testing.T) { + + hashFunc := 
testutillocal.NewBMTHasherFunc(128 * 128) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + errFunc := func(error) {} + + cache := testutillocal.NewCache() + cache.Init(ctx, errFunc) + cache.SetWriter(hashFunc) + cacheFunc := func(_ context.Context) param.SectionWriter { + return cache + } + + encryptFunc := func(_ context.Context) param.SectionWriter { + eFunc, err := New(testKey, uint32(42), cacheFunc) + if err != nil { + t.Fatal(err) + } + eFunc.SetWriter(cacheFunc) + eFunc.Init(ctx, errFunc) + return eFunc + } + + _, data := testutil.SerialData(chunkSize, 255, 0) + h := hasher.New(encryptFunc) + h.Init(ctx, func(error) {}) + h.Write(data) //0 + doubleRef := h.Sum(nil) + + enc := encryption.New(testKey, 0, 42, sha3.NewLegacyKeccak256) + cipherText, err := enc.Encrypt(data) + if err != nil { + t.Fatal(err) + } + cacheText := cache.Get(0) + if !bytes.Equal(cipherText, cacheText) { + log.Trace("data mismatch", "expect", cipherText, "got", cacheText) + t.Fatalf("encrypt onechunk; data mismatch") + } + + bmtTreePool := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + hc := bmt.New(bmtTreePool) + hc.Reset() + hc.SetLength(len(cipherText)) + hc.Write(cipherText) + cipherRef := hc.Sum(nil) + dataRef := doubleRef[encryption.KeyLength:] + if !bytes.Equal(dataRef, cipherRef) { + t.Fatalf("encrypt ref; expected %x, got %x", cipherRef, dataRef) + } +} + +func TestEncryptChunkWholeAndSections(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(128 * 128) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + errFunc := func(error) {} + + cache := testutillocal.NewCache() + cache.Init(ctx, errFunc) + cache.SetWriter(hashFunc) + cacheFunc := func(_ context.Context) param.SectionWriter { + return cache + } + + e, err := New(testKey, uint32(42), cacheFunc) + if err != nil { + t.Fatal(err) + } + e.Init(ctx, errFunc) + + _, data := testutil.SerialData(chunkSize, 255, 0) + e.Write(data) // 0 + e.SetLength(chunkSize) + e.Sum(nil) + + cacheCopy := make([]byte, chunkSize) + copy(cacheCopy, cache.Get(0)) + cache.Delete(0) + + e, err = New(testKey, uint32(42), cacheFunc) + if err != nil { + t.Fatal(err) + } + e.Init(ctx, errFunc) + + for i := 0; i < chunkSize; i += sectionSize { + e.SeekSection(i / sectionSize) + e.Write(data[i : i+sectionSize]) + } + e.SetLength(chunkSize) + e.Sum(nil) + + for i := 0; i < chunkSize; i += sectionSize { + chunked := cacheCopy[i : i+sectionSize] + sectioned := cache.Get(i / sectionSize) + if !bytes.Equal(chunked, sectioned) { + t.Fatalf("encrypt chunk full and section idx %d; expected %x, got %x", i/sectionSize, chunked, sectioned) + } + } +} + +func TestEncryptIntermediateChunk(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(128 * 128) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + errFunc := func(err error) { + log.Error("filehasher pipeline error", "err", err) + cancel() + } + + cache := testutillocal.NewCache() + cache.Init(ctx, errFunc) + cache.SetWriter(hashFunc) + cacheFunc := func(_ context.Context) param.SectionWriter { + return cache + } + + encryptRefFunc := func(_ context.Context) param.SectionWriter { + eFunc, err := New(testKey, uint32(42), cacheFunc) + if err != nil { + t.Fatal(err) + } + eFunc.Init(ctx, errFunc) + return eFunc + } + + h := hasher.New(encryptRefFunc) + + _, data := testutil.SerialData(chunkSize*branches, 255, 0) + for i := 0; i < chunkSize*branches; i += chunkSize { 
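+		// write the file data one chunk at a time, seeking to each chunk's
+		// section index before the write so the pipeline sees the sections in order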
+ h.SeekSection(i / chunkSize) + h.Write(data[i : i+chunkSize]) + } + h.SetLength(chunkSize * branches) + ref := h.Sum(nil) + select { + case <-ctx.Done(): + t.Fatalf("ctx done: %v", ctx.Err()) + default: + } + t.Logf("%x", ref) +} diff --git a/file/hasher/common_test.go b/file/hasher/common_test.go new file mode 100644 index 0000000000..fdb7a817d1 --- /dev/null +++ b/file/hasher/common_test.go @@ -0,0 +1,268 @@ +package hasher + +import ( + "bytes" + "context" + "encoding/binary" + "hash" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/testutil" + "golang.org/x/crypto/sha3" +) + +const ( + sectionSize = 32 + branches = 128 + chunkSize = 4096 +) + +var ( + dataLengths = []int{31, // 0 + 32, // 1 + 33, // 2 + 63, // 3 + 64, // 4 + 65, // 5 + chunkSize, // 6 + chunkSize + 31, // 7 + chunkSize + 32, // 8 + chunkSize + 63, // 9 + chunkSize + 64, // 10 + chunkSize * 2, // 11 + chunkSize*2 + 32, // 12 + chunkSize * 128, // 13 + chunkSize*128 + 31, // 14 + chunkSize*128 + 32, // 15 + chunkSize*128 + 64, // 16 + chunkSize * 129, // 17 + chunkSize * 130, // 18 + chunkSize * 128 * 128, // 19 + chunkSize*128*128 + 32, // 20 + } + expected = []string{ + "ece86edb20669cc60d142789d464d57bdf5e33cb789d443f608cbd81cfa5697d", // 0 + "0be77f0bb7abc9cd0abed640ee29849a3072ccfd1020019fe03658c38f087e02", // 1 + "3463b46d4f9d5bfcbf9a23224d635e51896c1daef7d225b86679db17c5fd868e", // 2 + "95510c2ff18276ed94be2160aed4e69c9116573b6f69faaeed1b426fea6a3db8", // 3 + "490072cc55b8ad381335ff882ac51303cc069cbcb8d8d3f7aa152d9c617829fe", // 4 + "541552bae05e9a63a6cb561f69edf36ffe073e441667dbf7a0e9a3864bb744ea", // 5 + "c10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef", // 6 + "91699c83ed93a1f87e326a29ccd8cc775323f9e7260035a5f014c975c5f3cd28", // 7 + "73759673a52c1f1707cbb61337645f4fcbd209cdc53d7e2cedaaa9f44df61285", // 8 + "db1313a727ffc184ae52a70012fbbf7235f551b9f2d2da04bf476abe42a3cb42", // 9 + "ade7af36ac0c7297dc1c11fd7b46981b629c6077bce75300f85b02a6153f161b", // 10 + "29a5fb121ce96194ba8b7b823a1f9c6af87e1791f824940a53b5a7efe3f790d9", // 11 + "61416726988f77b874435bdd89a419edc3861111884fd60e8adf54e2f299efd6", // 12 + "3047d841077898c26bbe6be652a2ec590a5d9bd7cd45d290ea42511b48753c09", // 13 + "e5c76afa931e33ac94bce2e754b1bb6407d07f738f67856783d93934ca8fc576", // 14 + "485a526fc74c8a344c43a4545a5987d17af9ab401c0ef1ef63aefcc5c2c086df", // 15 + "624b2abb7aefc0978f891b2a56b665513480e5dc195b4a66cd8def074a6d2e94", // 16 + "b8e1804e37a064d28d161ab5f256cc482b1423d5cd0a6b30fde7b0f51ece9199", // 17 + "59de730bf6c67a941f3b2ffa2f920acfaa1713695ad5deea12b4a121e5f23fa1", // 18 + "522194562123473dcfd7a457b18ee7dee8b7db70ed3cfa2b73f348a992fdfd3b", // 19 + "ed0cc44c93b14fef2d91ab3a3674eeb6352a42ac2f0bbe524711824aae1e7bcc", // 20 + } + + start = 0 + end = len(dataLengths) +) + +func init() { + testutil.Init() +} + +var ( + dummyHashFunc = func(_ context.Context) param.SectionWriter { + return newDummySectionWriter(chunkSize*branches, sectionSize, sectionSize, branches) + } + + // placeholder for cases where a hasher is not necessary + noHashFunc = func(_ context.Context) param.SectionWriter { + return nil + } + + logErrFunc = func(err error) { + log.Error("SectionWriter pipeline error", "err", err) + } +) + +// simple param.SectionWriter hasher that keeps the data written to it +// for later inspection +// TODO: see if this can be replaced with the fake hasher from storage module +type 
dummySectionWriter struct { + sectionSize int + digestSize int + branches int + data []byte + digest []byte + size int + span []byte + summed bool + index int + writer hash.Hash + mu sync.Mutex + wg sync.WaitGroup +} + +func newDummySectionWriter(cp int, sectionSize int, digestSize int, branches int) *dummySectionWriter { + return &dummySectionWriter{ + sectionSize: sectionSize, + digestSize: digestSize, + branches: branches, + data: make([]byte, cp), + writer: sha3.NewLegacyKeccak256(), + digest: make([]byte, digestSize), + } +} + +func (d *dummySectionWriter) Init(_ context.Context, _ func(error)) { +} + +func (d *dummySectionWriter) SetWriter(_ param.SectionWriterFunc) param.SectionWriter { + log.Error("dummySectionWriter does not support SectionWriter chaining") + return d +} + +// implements param.SectionWriter +func (d *dummySectionWriter) SeekSection(offset int) { + d.index = offset * d.SectionSize() +} + +// implements param.SectionWriter +func (d *dummySectionWriter) SetLength(length int) { + d.size = length +} + +// implements param.SectionWriter +func (d *dummySectionWriter) SetSpan(length int) { + d.span = make([]byte, 8) + binary.LittleEndian.PutUint64(d.span, uint64(length)) +} + +// implements param.SectionWriter +func (d *dummySectionWriter) Write(data []byte) (int, error) { + d.mu.Lock() + copy(d.data[d.index:], data) + d.size += len(data) + log.Trace("dummywriter write", "index", d.index, "size", d.size, "threshold", d.sectionSize*d.branches) + if d.isFull() { + d.summed = true + d.mu.Unlock() + d.sum() + } else { + d.mu.Unlock() + } + return len(data), nil +} + +// implements param.SectionWriter +func (d *dummySectionWriter) Sum(_ []byte) []byte { + log.Trace("dummy Sumcall", "size", d.size) + d.mu.Lock() + if !d.summed { + d.summed = true + d.mu.Unlock() + d.sum() + } else { + d.mu.Unlock() + } + return d.digest +} + +func (d *dummySectionWriter) sum() { + d.mu.Lock() + defer d.mu.Unlock() + d.writer.Write(d.span) + log.Trace("dummy sum writing span", "span", d.span) + for i := 0; i < d.size; i += d.writer.Size() { + sectionData := d.data[i : i+d.writer.Size()] + log.Trace("dummy sum write", "i", i/d.writer.Size(), "data", hexutil.Encode(sectionData), "size", d.size) + d.writer.Write(sectionData) + } + copy(d.digest, d.writer.Sum(nil)) + log.Trace("dummy sum result", "ref", hexutil.Encode(d.digest)) +} + +// implements param.SectionWriter +func (d *dummySectionWriter) Reset() { + d.mu.Lock() + defer d.mu.Unlock() + d.data = make([]byte, len(d.data)) + d.digest = make([]byte, d.digestSize) + d.size = 0 + d.summed = false + d.span = nil + d.writer.Reset() +} + +// implements param.SectionWriter +func (d *dummySectionWriter) BlockSize() int { + return d.sectionSize +} + +// implements param.SectionWriter +func (d *dummySectionWriter) SectionSize() int { + return d.sectionSize +} + +// implements param.SectionWriter +func (d *dummySectionWriter) Size() int { + return d.sectionSize +} + +// implements param.SectionWriter +func (d *dummySectionWriter) Branches() int { + return d.branches +} + +func (d *dummySectionWriter) isFull() bool { + return d.size == d.sectionSize*d.branches +} + +// TestDummySectionWriter +func TestDummySectionWriter(t *testing.T) { + + w := newDummySectionWriter(chunkSize*2, sectionSize, sectionSize, branches) + w.Reset() + + _, data := testutil.SerialData(sectionSize*2, 255, 0) + + w.SeekSection(branches) + w.Write(data[:sectionSize]) + w.SeekSection(branches + 1) + w.Write(data[sectionSize:]) + if 
!bytes.Equal(w.data[chunkSize:chunkSize+sectionSize*2], data) {
+		t.Fatalf("Write double pos %d: expected %x, got %x", chunkSize, data, w.data[chunkSize:chunkSize+sectionSize*2])
+	}
+
+	correctDigestHex := "0x52eefd0c37895a8845d4a6cf6c6b56980e448376e55eb45717663ab7b3fc8d53"
+	w.SetLength(chunkSize * 2)
+	w.SetSpan(chunkSize * 2)
+	digest := w.Sum(nil)
+	digestHex := hexutil.Encode(digest)
+	if digestHex != correctDigestHex {
+		t.Fatalf("Digest: 2xsectionSize*1; expected %s, got %s", correctDigestHex, digestHex)
+	}
+
+	w = newDummySectionWriter(chunkSize*2, sectionSize*2, sectionSize*2, branches/2)
+	w.Reset()
+	w.SeekSection(branches / 2)
+	w.Write(data)
+	if !bytes.Equal(w.data[chunkSize:chunkSize+sectionSize*2], data) {
+		t.Fatalf("Write double pos %d: expected %x, got %x", chunkSize, data, w.data[chunkSize:chunkSize+sectionSize*2])
+	}
+
+	correctDigestHex += zeroHex
+	w.SetLength(chunkSize * 2)
+	w.SetSpan(chunkSize * 2)
+	digest = w.Sum(nil)
+	digestHex = hexutil.Encode(digest)
+	if digestHex != correctDigestHex {
+		t.Fatalf("Digest 1xsectionSize*2; expected %s, got %s", correctDigestHex, digestHex)
+	}
+}
diff --git a/file/hasher/hasher.go b/file/hasher/hasher.go
new file mode 100644
index 0000000000..8ed47403fe
--- /dev/null
+++ b/file/hasher/hasher.go
@@ -0,0 +1,126 @@
+package hasher
+
+import (
+	"context"
+	"errors"
+
+	"github.com/ethersphere/swarm/log"
+	"github.com/ethersphere/swarm/param"
+)
+
+// Hasher is a bmt.SectionWriter that executes the file hashing algorithm on arbitrary data
+type Hasher struct {
+	target  *target
+	params  *treeParams
+	index   *jobIndex
+	errFunc func(error)
+	ctx     context.Context
+
+	job   *job // current level 1 job being written to
+	size  int
+	count int
+}
+
+// New creates a new Hasher object
+// hashFunc is used both to create the *bmt.Hashers that hash the incoming data
+// and as the underlying bmt.SectionWriter for the asynchronous hasher jobs. It may be pipelined to other components with the same interface
+// TODO: sectionSize and branches should be inferred from underlying writer, not shared across job and hasher
+func New(hashFunc param.SectionWriterFunc) *Hasher {
+	h := &Hasher{
+		target: newTarget(),
+		index:  newJobIndex(9),
+		params: newTreeParams(hashFunc),
+	}
+	h.job = newJob(h.params, h.target, h.index, 1, 0)
+	return h
+}
+
+func (h *Hasher) SetWriter(hashFunc param.SectionWriterFunc) param.SectionWriter {
+	h.params = newTreeParams(hashFunc)
+	return h
+}
+
+// Init implements param.SectionWriter
+func (h *Hasher) Init(ctx context.Context, errFunc func(error)) {
+	h.errFunc = errFunc
+	h.params.SetContext(ctx)
+	h.job.start()
+}
+
+// Write implements param.SectionWriter
+// It is a non-blocking call that hashes a data chunk and passes the resulting reference to the hash job representing
+// the intermediate chunk holding the data references
+// TODO: enforce buffered writes and limits
+// TODO: attempt to omit modulo calc on every pass
+// TODO: preallocate full size span slice
+func (h *Hasher) Write(b []byte) (int, error) {
+	if h.count%h.params.Branches == 0 && h.count > 0 {
+		h.job = h.job.Next()
+	}
+	go func(i int, jb *job) {
+		hasher := h.params.GetWriter()
+		hasher.SeekSection(-1)
+		hasher.Write(b)
+		l := len(b)
+		log.Trace("data write", "count", i, "size", l)
+		jb.write(i%h.params.Branches, hasher.Sum(nil))
+		h.params.PutWriter(hasher)
+	}(h.count, h.job)
+	h.size += len(b)
+	h.count++
+	return len(b), nil
+}
+
+// Sum implements param.SectionWriter
+// It is a blocking call that calculates the target level and section index of the received data
+// and alerts hasher jobs that the end of writes has been reached
+// It returns the root hash
+func (h *Hasher) Sum(b []byte) []byte {
+	sectionCount := dataSizeToSectionIndex(h.size, h.params.SectionSize)
+	targetLevel := getLevelsFromLength(h.size, h.params.SectionSize, h.params.Branches)
+	h.target.Set(h.size, sectionCount, targetLevel)
+	ref := <-h.target.Done()
+	if b == nil {
+		return ref
+	}
+	return append(b, ref...)
+} + +func (h *Hasher) SetSpan(length int) { +} + +func (h *Hasher) SetLength(length int) { + h.size = length +} + +// Seek implements io.Seeker in param.SectionWriter +func (h *Hasher) SeekSection(offset int) { + h.errFunc(errors.New("Hasher cannot seek")) +} + +// Reset implements param.SectionWriter +func (h *Hasher) Reset() { + h.size = 0 + h.count = 0 + h.target = newTarget() + h.job = newJob(h.params, h.target, h.index, 1, 0) +} + +func (h *Hasher) BlockSize() int { + return h.params.ChunkSize +} + +// SectionSize implements param.SectionWriter +func (h *Hasher) SectionSize() int { + return h.params.ChunkSize +} + +// DigestSize implements param.SectionWriter +func (h *Hasher) Size() int { + return h.params.SectionSize +} + +// DigestSize implements param.SectionWriter +func (h *Hasher) Branches() int { + return h.params.Branches +} diff --git a/file/hasher/hasher_test.go b/file/hasher/hasher_test.go new file mode 100644 index 0000000000..11a7a080c8 --- /dev/null +++ b/file/hasher/hasher_test.go @@ -0,0 +1,172 @@ +package hasher + +import ( + "context" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/file/testutillocal" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/testutil" +) + +// TestHasherJobTopHash verifies that the top hash on the first level is correctly set even though the Hasher writes asynchronously to the underlying job +func TestHasherJobTopHash(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(0) + + _, data := testutil.SerialData(chunkSize*branches, 255, 0) + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + var i int + for i = 0; i < chunkSize*branches; i += chunkSize { + h.Write(data[i : i+chunkSize]) + } + h.Sum(nil) + levelOneTopHash := hexutil.Encode(h.index.GetTopHash(1)) + correctLevelOneTopHash := "0xc10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if levelOneTopHash != correctLevelOneTopHash { + t.Fatalf("tophash; expected %s, got %s", correctLevelOneTopHash, levelOneTopHash) + } + +} + +// TestHasherOneFullChunk verifies the result of writing a single data chunk to Hasher +func TestHasherOneFullChunk(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(0) + + _, data := testutil.SerialData(chunkSize*branches, 255, 0) + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + var i int + for i = 0; i < chunkSize*branches; i += chunkSize { + h.Write(data[i : i+chunkSize]) + } + ref := h.Sum(nil) + correctRootHash := "0x3047d841077898c26bbe6be652a2ec590a5d9bd7cd45d290ea42511b48753c09" + rootHash := hexutil.Encode(ref) + if rootHash != correctRootHash { + t.Fatalf("roothash; expected %s, got %s", correctRootHash, rootHash) + } +} + +// TestHasherOneFullChunk verifies that Hasher creates new jobs on branch thresholds +func TestHasherJobChange(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(0) + + _, data := testutil.SerialData(chunkSize*branches*branches, 255, 0) + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + jobs := make(map[string]int) + for i := 0; i < chunkSize*branches*branches; i += chunkSize { + h.Write(data[i : i+chunkSize]) + jobs[h.job.String()]++ + } + i := 0 + for _, v := range jobs { + if v != branches { + t.Fatalf("jobwritecount writes: expected %d, got %d", branches, v) + } + i++ + } + if 
i != branches { + t.Fatalf("jobwritecount jobs: expected %d, got %d", branches, i) + } +} + +// TestHasherONeFullLevelOneChunk verifies the result of writing branches times data chunks to Hasher +func TestHasherOneFullLevelOneChunk(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(128) + + _, data := testutil.SerialData(chunkSize*branches*branches, 255, 0) + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + var i int + for i = 0; i < chunkSize*branches*branches; i += chunkSize { + h.Write(data[i : i+chunkSize]) + } + ref := h.Sum(nil) + correctRootHash := "0x522194562123473dcfd7a457b18ee7dee8b7db70ed3cfa2b73f348a992fdfd3b" + rootHash := hexutil.Encode(ref) + if rootHash != correctRootHash { + t.Fatalf("roothash; expected %s, got %s", correctRootHash, rootHash) + } +} + +func TestHasherVector(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(128) + + var mismatch int + for i, dataLength := range dataLengths { + log.Info("hashervector start", "i", i, "l", dataLength) + eq := true + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + _, data := testutil.SerialData(dataLength, 255, 0) + for j := 0; j < dataLength; j += chunkSize { + size := chunkSize + if dataLength-j < chunkSize { + size = dataLength - j + } + h.Write(data[j : j+size]) + } + //h.SetLength(dataLength) + ref := h.Sum(nil) + correctRefHex := "0x" + expected[i] + refHex := hexutil.Encode(ref) + if refHex != correctRefHex { + mismatch++ + eq = false + } + t.Logf("[%7d+%4d]\t%v\tref: %x\texpect: %s", dataLength/chunkSize, dataLength%chunkSize, eq, ref, expected[i]) + } + if mismatch > 0 { + t.Fatalf("mismatches: %d/%d", mismatch, end-start) + } +} + +// BenchmarkHasher generates benchmarks that are comparable to the pyramid hasher +func BenchmarkHasher(b *testing.B) { + for i := start; i < end; i++ { + b.Run(fmt.Sprintf("%d/%d", i, dataLengths[i]), benchmarkHasher) + } +} + +func benchmarkHasher(b *testing.B) { + params := strings.Split(b.Name(), "/") + dataLengthParam, err := strconv.ParseInt(params[2], 10, 64) + if err != nil { + b.Fatal(err) + } + dataLength := int(dataLengthParam) + + hashFunc := testutillocal.NewBMTHasherFunc(128) + _, data := testutil.SerialData(dataLength, 255, 0) + + for j := 0; j < b.N; j++ { + h := New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, logErrFunc) + for i := 0; i < dataLength; i += chunkSize { + size := chunkSize + if dataLength-i < chunkSize { + size = dataLength - i + } + h.Write(data[i : i+size]) + } + h.Sum(nil) + } +} diff --git a/file/hasher/index.go b/file/hasher/index.go new file mode 100644 index 0000000000..9e8ab21071 --- /dev/null +++ b/file/hasher/index.go @@ -0,0 +1,82 @@ +package hasher + +import ( + "fmt" + "sync" +) + +// keeps an index of all the existing jobs for a file hashing operation +// sorted by level +// +// it also keeps all the "top hashes", ie hashes on first data section index of every level +// these are needed in case of balanced tree results, since the hashing result would be +// lost otherwise, due to the job not having any intermediate storage of any data +type jobIndex struct { + maxLevels int + jobs []sync.Map + topHashes [][]byte + mu sync.Mutex +} + +func newJobIndex(maxLevels int) *jobIndex { + ji := &jobIndex{ + maxLevels: maxLevels, + } + for i := 0; i < maxLevels; i++ { + ji.jobs = append(ji.jobs, sync.Map{}) + } + return ji +} + +// 
implements Stringer interface +func (ji *jobIndex) String() string { + return fmt.Sprintf("%p", ji) +} + +// Add adds a job to the index at the level +// and data section index specified in the job +func (ji *jobIndex) Add(jb *job) { + //log.Trace("adding job", "job", jb) + ji.jobs[jb.level].Store(jb.dataSection, jb) +} + +// Get retrieves a job from the job index +// based on the level of the job and its data section index +// if a job for the level and section index does not exist this method returns nil +func (ji *jobIndex) Get(lvl int, section int) *job { + jb, ok := ji.jobs[lvl].Load(section) + if !ok { + return nil + } + return jb.(*job) +} + +// Delete removes a job from the job index +// leaving it to be garbage collected when +// the reference in the main code is relinquished +func (ji *jobIndex) Delete(jb *job) { + ji.jobs[jb.level].Delete(jb.dataSection) +} + +// AddTopHash should be called by a job when a hash is written to the first index of a level +// since the job doesn't store any data written to it (just passing it through to the underlying writer) +// this is needed for the edge case of balanced trees +func (ji *jobIndex) AddTopHash(ref []byte) { + ji.mu.Lock() + defer ji.mu.Unlock() + ji.topHashes = append(ji.topHashes, ref) + //log.Trace("added top hash", "length", len(ji.topHashes), "index", ji) +} + +// GetJobHash gets the current top hash for a particular level set by AddTopHash +func (ji *jobIndex) GetTopHash(lvl int) []byte { + ji.mu.Lock() + defer ji.mu.Unlock() + return ji.topHashes[lvl-1] +} + +func (ji *jobIndex) GetTopHashLevel() int { + ji.mu.Lock() + defer ji.mu.Unlock() + return len(ji.topHashes) +} diff --git a/file/hasher/job.go b/file/hasher/job.go new file mode 100644 index 0000000000..1ada233210 --- /dev/null +++ b/file/hasher/job.go @@ -0,0 +1,327 @@ +package hasher + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" +) + +// necessary metadata across asynchronous input +type jobUnit struct { + index int + data []byte + count int +} + +// encapsulates one single intermediate chunk to be hashed +type job struct { + target *target + params *treeParams + index *jobIndex + + level int // level in tree + dataSection int // data section index + cursorSection int32 // next write position in job + endCount int32 // number of writes to be written to this job (0 means write to capacity) + lastSectionSize int // data size on the last data section write + firstSectionData []byte // store first section of data written to solve the dangling chunk edge case + + writeC chan jobUnit + writer param.SectionWriter // underlying data processor + doneC chan struct{} // pointer to target doneC channel, set to nil in process() when closed + + mu sync.Mutex +} + +func newJob(params *treeParams, tgt *target, jobIndex *jobIndex, lvl int, dataSection int) *job { + jb := &job{ + params: params, + index: jobIndex, + level: lvl, + dataSection: dataSection, + writeC: make(chan jobUnit), + target: tgt, + doneC: nil, + } + if jb.index == nil { + jb.index = newJobIndex(9) + } + targetLevel := tgt.Level() + if targetLevel == 0 { + log.Trace("target not set", "level", lvl) + jb.doneC = tgt.doneC + + } else { + targetCount := tgt.Count() + jb.endCount = int32(jb.targetCountToEndCount(targetCount)) + } + log.Trace("target count", "level", lvl, "count", tgt.Count()) + + jb.index.Add(jb) + return jb +} + +func (jb *job) start() { + jb.writer = jb.params.GetWriter() + go 
jb.process() +} + +// implements Stringer interface +func (jb *job) String() string { + return fmt.Sprintf("job: l:%d,s:%d", jb.level, jb.dataSection) +} + +// atomically increments the write counter of the job +func (jb *job) inc() int { + return int(atomic.AddInt32(&jb.cursorSection, 1)) +} + +// atomically returns the write counter of the job +func (jb *job) count() int { + return int(atomic.LoadInt32(&jb.cursorSection)) +} + +// size returns the byte size of the span the job represents +// if job is last index in a level and writes have been finalized, it will return the target size +// otherwise, regardless of job index, it will return the size according to the current write count +// TODO: returning expected size in one case and actual size in another can lead to confusion +func (jb *job) size() int { + jb.mu.Lock() + count := int(jb.cursorSection) //jb.count() + endCount := int(jb.endCount) //int(atomic.LoadInt32(&jb.endCount)) + jb.mu.Unlock() + if endCount%jb.params.Branches == 0 { + return count * jb.params.SectionSize * jb.params.Spans[jb.level] + } + log.Trace("size", "sections", jb.target.sections, "size", jb.target.Size(), "endcount", endCount, "level", jb.level) + return jb.target.Size() % (jb.params.Spans[jb.level] * jb.params.SectionSize * jb.params.Branches) +} + +// add data to job +// does no checking for data length or index validity +// TODO: rename index param not to confuse with index object +func (jb *job) write(index int, data []byte) { + + jb.inc() + + // if a write is received at the first datasection of a level we need to store this hash + // in case of a balanced tree and we need to send it to resultC later + // at the time of hasing of a balanced tree we have no way of knowing for sure whether + // that is the end of the job or not + if jb.dataSection == 0 && index == 0 { + topHashLevel := jb.index.GetTopHashLevel() + if topHashLevel < jb.level { + log.Trace("have tophash", "level", jb.level, "ref", hexutil.Encode(data)) + jb.index.AddTopHash(data) + } + } + jb.writeC <- jobUnit{ + index: index, + data: data, + } +} + +// runs in loop until: +// - sectionSize number of job writes have occurred (one full chunk) +// - data write is finalized and targetcount for this chunk was already reached +// - data write is finalized and targetcount is reached on a subsequent job write +func (jb *job) process() { + + log.Trace("starting job process", "level", jb.level, "sec", jb.dataSection) + + var processCount int + defer jb.destroy() + + // is set when data write is finished, AND + // the final data section falls within the span of this job + // if not, loop will only exit on Branches writes +OUTER: + for { + select { + + // enter here if new data is written to the job + // TODO: Error if calculated write count exceed chunk + case entry := <-jb.writeC: + + // split the contents to fit the underlying SectionWriter + entrySections := len(entry.data) / jb.writer.SectionSize() + jb.mu.Lock() + endCount := int(jb.endCount) + oldProcessCount := processCount + processCount += entrySections + jb.mu.Unlock() + if entry.index == 0 { + jb.firstSectionData = entry.data + } + log.Trace("job entry", "datasection", jb.dataSection, "num sections", entrySections, "level", jb.level, "processCount", oldProcessCount, "endcount", endCount, "index", entry.index, "data", hexutil.Encode(entry.data)) + + // TODO: this write is superfluous when the received data is the root hash + var offset int + for i := 0; i < entrySections; i++ { + idx := entry.index + i + data := entry.data[offset : 
offset+jb.writer.SectionSize()] + log.Trace("job write", "datasection", jb.dataSection, "level", jb.level, "processCount", oldProcessCount+i, "endcount", endCount, "index", entry.index+i, "data", hexutil.Encode(data)) + jb.writer.SeekSection(idx) + jb.writer.Write(data) + offset += jb.writer.SectionSize() + } + + // since newcount is incremented above it can only equal endcount if this has been set in the case below, + // which means data write has been completed + // otherwise if we reached the chunk limit we also continue to hashing + if processCount == endCount { + log.Trace("quitting writec - endcount", "c", processCount, "level", jb.level) + break OUTER + } + if processCount == jb.writer.Branches() { + log.Trace("quitting writec - branches") + break OUTER + } + + // enter here if data writes have been completed + // TODO: this case currently executes for all cycles after data write is complete for which writes to this job do not happen. perhaps it can be improved + case <-jb.doneC: + jb.mu.Lock() + jb.doneC = nil + log.Trace("doneloop", "level", jb.level, "processCount", processCount, "endcount", jb.endCount) + //count := jb.count() + + // if the target count falls within the span of this job + // set the endcount so we know we have to do extra calculations for + // determining span in case of unbalanced tree + targetCount := jb.target.Count() + jb.endCount = int32(jb.targetCountToEndCount(targetCount)) + log.Trace("doneloop done", "level", jb.level, "targetcount", jb.target.Count(), "endcount", jb.endCount) + + // if we have reached the end count for this chunk, we proceed to hashing + // this case is important when write to the level happen after this goroutine + // registers that data writes have been completed + if processCount > 0 && processCount == int(jb.endCount) { + log.Trace("quitting donec", "level", jb.level, "processcount", processCount) + jb.mu.Unlock() + break OUTER + } + jb.mu.Unlock() + } + } + + jb.sum() +} + +func (jb *job) sum() { + + targetLevel := jb.target.Level() + if targetLevel == jb.level { + jb.target.resultC <- jb.index.GetTopHash(jb.level) + return + } + + // get the size of the span and execute the hash digest of the content + size := jb.size() + //span := bmt.LengthToSpan(size) + refSize := jb.count() * jb.params.SectionSize + jb.writer.SetLength(refSize) + jb.writer.SetSpan(size) + log.Trace("job sum", "count", jb.count(), "refsize", refSize, "size", size, "datasection", jb.dataSection, "level", jb.level, "targetlevel", targetLevel, "endcount", jb.endCount) + ref := jb.writer.Sum(nil) + + // endCount > 0 means this is the last chunk on the level + // the hash from the level below the target level will be the result + belowRootLevel := targetLevel - 1 + if jb.endCount > 0 && jb.level == belowRootLevel { + jb.target.resultC <- ref + return + } + + // retrieve the parent and the corresponding section in it to write to + parent := jb.parent() + log.Trace("have parent", "level", jb.level, "jb p", fmt.Sprintf("%p", jb), "jbp p", fmt.Sprintf("%p", parent)) + nextLevel := jb.level + 1 + parentSection := dataSectionToLevelSection(jb.params, nextLevel, jb.dataSection) + + // in the event that we have a balanced tree and a chunk with single reference below the target level + // we move the single reference up to the penultimate level + if jb.endCount == 1 { + ref = jb.firstSectionData + for parent.level < belowRootLevel { + log.Trace("parent write skip", "level", parent.level) + oldParent := parent + parent = parent.parent() + oldParent.destroy() + nextLevel += 1 
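+			// recompute the section index in the new parent for the level we just climbed to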
+ parentSection = dataSectionToLevelSection(jb.params, nextLevel, jb.dataSection) + } + } + parent.write(parentSection, ref) + +} + +// determine whether the given data section count falls within the span of the current job +func (jb *job) targetWithinJob(targetSection int) (int, bool) { + var endIndex int + var ok bool + + // span one level above equals the data size of 128 units of one section on this level + // using the span table saves one multiplication + //dataBoundary := dataSectionToLevelBoundary(jb.params, jb.level, jb.dataSection) + dataBoundary := dataSectionToLevelBoundary(jb.params, jb.level, jb.dataSection) + upperLimit := dataBoundary + jb.params.Spans[jb.level+1] + + // the data section is the data section index where the span of this job starts + if targetSection >= dataBoundary && targetSection < upperLimit { + + // data section index must be divided by corresponding section size on the job's level + // then wrap on branch period to find the correct section within this job + endIndex = (targetSection / jb.params.Spans[jb.level]) % jb.params.Branches + + ok = true + } + return endIndex, ok +} + +// if last data index falls within the span, return the appropriate end count for the level +// otherwise return 0 (which means job write until limit) +func (jb *job) targetCountToEndCount(targetCount int) int { + endIndex, ok := jb.targetWithinJob(targetCount - 1) + if !ok { + return 0 + } + return endIndex + 1 +} + +// returns the parent job of the receiver job +// a new parent job is created if none exists for the slot +func (jb *job) parent() *job { + jb.index.mu.Lock() + defer jb.index.mu.Unlock() + newLevel := jb.level + 1 + // Truncate to even quotient which is the actual logarithmic boundary of the data section under the span + newDataSection := dataSectionToLevelBoundary(jb.params, jb.level+1, jb.dataSection) + parent := jb.index.Get(newLevel, newDataSection) + if parent != nil { + return parent + } + jbp := newJob(jb.params, jb.target, jb.index, jb.level+1, newDataSection) + jbp.start() + return jbp +} + +// Next creates the job for the next data section span on the same level as the receiver job +// this is only meant to be called once for each job, consecutive calls will overwrite index with new empty job +func (jb *job) Next() *job { + jbn := newJob(jb.params, jb.target, jb.index, jb.level, jb.dataSection+jb.params.Spans[jb.level+1]) + jbn.start() + return jbn +} + +// cleans up the job; reset hasher and remove pointer to job from index +func (jb *job) destroy() { + if jb.writer != nil { + jb.params.PutWriter(jb.writer) + } + jb.index.Delete(jb) +} diff --git a/file/hasher/job_test.go b/file/hasher/job_test.go new file mode 100644 index 0000000000..6d4cbaa157 --- /dev/null +++ b/file/hasher/job_test.go @@ -0,0 +1,654 @@ +package hasher + +import ( + "context" + "fmt" + "math/rand" + "strconv" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/file/testutillocal" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/testutil" + "golang.org/x/crypto/sha3" +) + +const ( + zeroHex = "0000000000000000000000000000000000000000000000000000000000000000" +) + +// TestTreeParams verifies that params are set correctly by the param constructor +func TestTreeParams(t *testing.T) { + + params := newTreeParams(dummyHashFunc) + + if params.SectionSize != 32 { + t.Fatalf("section: expected %d, got %d", sectionSize, 
params.SectionSize) + } + + if params.Branches != 128 { + t.Fatalf("branches: expected %d, got %d", branches, params.Branches) + } + + if params.Spans[2] != branches*branches { + t.Fatalf("span %d: expected %d, got %d", 2, branches*branches, params.Spans[2]) + } + +} + +// TestTarget verifies that params are set correctly by the target constructor +func TestTarget(t *testing.T) { + + tgt := newTarget() + tgt.Set(32, 1, 2) + + if tgt.size != 32 { + t.Fatalf("target size expected %d, got %d", 32, tgt.size) + } + + if tgt.sections != 1 { + t.Fatalf("target sections expected %d, got %d", 1, tgt.sections) + } + + if tgt.level != 2 { + t.Fatalf("target level expected %d, got %d", 2, tgt.level) + } +} + +// TestJobTargetWithinDefault verifies the calculation of whether a final data section index +// falls within a particular job's span, without regard to differing SectionSize +func TestJobTargetWithinDefault(t *testing.T) { + params := newTreeParams(dummyHashFunc) + index := newJobIndex(9) + tgt := newTarget() + + jb := newJob(params, tgt, index, 1, branches*branches) + defer jb.destroy() + + finalSize := chunkSize*branches + chunkSize*2 + finalCount := dataSizeToSectionCount(finalSize, sectionSize) + log.Trace("within test", "size", finalSize, "count", finalCount) + c, ok := jb.targetWithinJob(finalCount - 1) + if !ok { + t.Fatalf("target %d within %d: expected true", finalCount, jb.level) + } + if c != 1 { + t.Fatalf("target %d within %d: expected %d, got %d", finalCount, jb.level, 1, c) + } +} + +// TestJobTargetWithinDifferentSections does the same as TestJobTargetWithinDefault but +// with SectionSize/Branches settings differing between client target and underlying writer +func TestJobTargetWithinDifferentSections(t *testing.T) { + dummyHashDoubleFunc := func(_ context.Context) param.SectionWriter { + return newDummySectionWriter(chunkSize, sectionSize*2, sectionSize*2, branches/2) + } + params := newTreeParams(dummyHashDoubleFunc) + index := newJobIndex(9) + tgt := newTarget() + + jb := newJob(params, tgt, index, 1, 0) + defer jb.destroy() + + finalSize := chunkSize + finalCount := dataSizeToSectionCount(finalSize, sectionSize) + log.Trace("within test", "size", finalSize, "count", finalCount) + c, ok := jb.targetWithinJob(finalCount - 1) + if !ok { + t.Fatalf("target %d within %d: expected true", finalCount, jb.level) + } + if c != 1 { + t.Fatalf("target %d within %d: expected %d, got %d", finalCount, jb.level, 1, c) + } +} + +// TestNewJob verifies that a job is initialized with the correct values +func TestNewJob(t *testing.T) { + + params := newTreeParams(dummyHashFunc) + params.Debug = true + + tgt := newTarget() + jb := newJob(params, tgt, nil, 1, branches*branches+1) + if jb.level != 1 { + t.Fatalf("job level expected 1, got %d", jb.level) + } + if jb.dataSection != branches*branches+1 { + t.Fatalf("datasectionindex: expected %d, got %d", branches*branches+1, jb.dataSection) + } + tgt.Set(0, 0, 0) + jb.destroy() +} + +// TestJobSize verifies the data size calculation used for calculating the span of data +// under a particular level reference +// it tests both a balanced and an unbalanced tree +func TestJobSize(t *testing.T) { + params := newTreeParams(dummyHashFunc) + params.Debug = true + index := newJobIndex(9) + + tgt := newTarget() + jb := newJob(params, tgt, index, 3, 0) + jb.cursorSection = 1 + jb.endCount = 1 + size := chunkSize*branches + chunkSize + sections :=
dataSizeToSectionIndex(size, sectionSize) + 1 + tgt.Set(size, sections, 3) + jobSize := jb.size() + if jobSize != size { + t.Fatalf("job size: expected %d, got %d", size, jobSize) + } + jb.destroy() + + tgt = newTarget() + jb = newJob(params, tgt, index, 3, 0) + jb.cursorSection = 1 + jb.endCount = 1 + size = chunkSize * branches * branches + sections = dataSizeToSectionIndex(size, sectionSize) + 1 + tgt.Set(size, sections, 3) + jobSize = jb.size() + if jobSize != size { + t.Fatalf("job size: expected %d, got %d", size, jobSize) + } + jb.destroy() + +} + +// TestJobTarget verifies that the underlying calculation for determining whether +// a data section index is within a level's span is correct +func TestJobTarget(t *testing.T) { + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + params.Debug = true + index := newJobIndex(9) + + jb := newJob(params, tgt, index, 1, branches*branches) + + // this is less than chunksize * 128 + // it will not be in the job span + finalSize := chunkSize + sectionSize + 1 + finalSection := dataSizeToSectionIndex(finalSize, sectionSize) + c, ok := jb.targetWithinJob(finalSection) + if ok { + t.Fatalf("targetwithinjob: expected false") + } + jb.destroy() + + // chunkSize*128+chunkSize*2 (532480) is within chunksize*128 (524288) and chunksize*128*2 (1048576) + // it will be within the job span + finalSize = chunkSize*branches + chunkSize*2 + finalSection = dataSizeToSectionIndex(finalSize, sectionSize) + c, ok = jb.targetWithinJob(finalSection) + if !ok { + t.Fatalf("targetwithinjob section %d: expected true", branches*branches) + } + if c != 1 { + t.Fatalf("targetwithinjob section %d: expected %d, got %d", branches*branches, 1, c) + } + c = jb.targetCountToEndCount(finalSection + 1) + if c != 2 { + t.Fatalf("targetcounttoendcount section %d: expected %d, got %d", branches*branches, 2, c) + } + jb.destroy() +} + +// TestJobIndex verifies that the job constructor adds the job to the job index +// and removes it on job destruction +func TestJobIndex(t *testing.T) { + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + + jb := newJob(params, tgt, nil, 1, branches) + jobIndex := jb.index + jbGot := jobIndex.Get(1, branches) + if jb != jbGot { + t.Fatalf("jbIndex get: expect %p, got %p", jb, jbGot) + } + jbGot.destroy() + if jobIndex.Get(1, branches) != nil { + t.Fatalf("jbIndex delete: expected nil") + } +} + +// TestJobGetNext verifies that the new job constructed through the job.Next() method +// has the correct level and data section index +func TestJobGetNext(t *testing.T) { + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + params.Debug = true + + jb := newJob(params, tgt, nil, 1, branches*branches) + jbn := jb.Next() + if jbn == nil { + t.Fatalf("next job: nil") + } + if jbn.level != 1 { + t.Fatalf("nextjob level: expected %d, got %d", 1, jbn.level) + } + if jbn.dataSection != jb.dataSection+branches*branches { + t.Fatalf("nextjob section: expected %d, got %d", jb.dataSection+branches*branches, jbn.dataSection) + } +} + +// TestJobWriteTwoAndFinish writes two references to a job and sets the job target to two chunks +// it verifies that the job count after the writes is two, and the hash is correct +func TestJobWriteTwoAndFinish(t *testing.T) { + + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + _, data := testutil.SerialData(sectionSize*2, 255, 0) + jb.write(0, data[:sectionSize]) + jb.write(1, data[sectionSize:]) + + finalSize := chunkSize * 2 + finalSection
:= dataSizeToSectionIndex(finalSize, sectionSize) + tgt.Set(finalSize, finalSection-1, 2) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*199) + defer cancel() + select { + case ref := <-tgt.Done(): + correctRefHex := "0xe1553e1a3a6b73f96e6fc48318895e401e7db2972962ee934633fa8b3eaaf78b" + refHex := hexutil.Encode(ref) + if refHex != correctRefHex { + t.Fatalf("job write two and finish: expected %s, got %s", correctRefHex, refHex) + } + case <-ctx.Done(): + t.Fatalf("timeout: %v", ctx.Err()) + } + + if jb.count() != 2 { + t.Fatalf("jobcount: expected %d, got %d", 2, jb.count()) + } +} + +// TestJobGetParent verifies that the parent returned from two jobs' parent() calls +// that are within the same span as the parent chunk of references is the same +// BUG: not guaranteed to return same parent when run with eg -count 100 +func TestJobGetParent(t *testing.T) { + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + + jb := newJob(params, tgt, nil, 1, branches*branches) + jb.start() + jbp := jb.parent() + if jbp == nil { + t.Fatalf("parent: nil") + } + if jbp.level != 2 { + t.Fatalf("parent level: expected %d, got %d", 2, jbp.level) + } + if jbp.dataSection != 0 { + t.Fatalf("parent data section: expected %d, got %d", 0, jbp.dataSection) + } + jbGot := jb.index.Get(2, 0) + if jbGot == nil { + t.Fatalf("index get: nil") + } + + jbNext := jb.Next() + jbpNext := jbNext.parent() + if jbpNext != jbp { + t.Fatalf("next parent: expected %p, got %p", jbp, jbpNext) + } +} + +// TestJobWriteParentSection verifies that a data write translates to a write +// in the correct section of its parent +func TestJobWriteParentSection(t *testing.T) { + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + index := newJobIndex(9) + + jb := newJob(params, tgt, index, 1, 0) + jbn := jb.Next() + _, data := testutil.SerialData(sectionSize*2, 255, 0) + jbn.write(0, data[:sectionSize]) + jbn.write(1, data[sectionSize:]) + + finalSize := chunkSize*branches + chunkSize*2 + finalSection := dataSizeToSectionIndex(finalSize, sectionSize) + tgt.Set(finalSize, finalSection, 3) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + select { + case <-tgt.Done(): + t.Fatalf("unexpected done") + case <-ctx.Done(): + } + jbnp := jbn.parent() + if jbnp.count() != 1 { + t.Fatalf("parent count: expected %d, got %d", 1, jbnp.count()) + } + correctRefHex := "0xe1553e1a3a6b73f96e6fc48318895e401e7db2972962ee934633fa8b3eaaf78b" + + // extract data in section 2 from the writer + // TODO: overload writer to provide a get method to extract data to improve clarity + w := jbnp.writer.(*dummySectionWriter) + w.mu.Lock() + parentRef := w.data[32:64] + w.mu.Unlock() + parentRefHex := hexutil.Encode(parentRef) + if parentRefHex != correctRefHex { + t.Fatalf("parent data: expected %s, got %s", correctRefHex, parentRefHex) + } +} + +// TestJobWriteFull verifies the hashing result of the write of a balanced tree +// where the simulated tree is chunkSize*branches worth of data +func TestJobWriteFull(t *testing.T) { + + tgt := newTarget() + params := newTreeParams(dummyHashFunc) + + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + _, data := testutil.SerialData(chunkSize, 255, 0) + for i := 0; i < chunkSize; i += sectionSize { + jb.write(i/sectionSize, data[i:i+sectionSize]) + } + + tgt.Set(chunkSize, branches, 2) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + select { + case ref := <-tgt.Done(): + 
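+ // Illustrative aside (not part of the original change): the root reference
+ // arrives here through the flow sketched in job.go: write() increments the
+ // job's processed count, tgt.Set() closes doneC, the job derives its endCount
+ // from target.Count(), and once the processed count reaches endCount the job
+ // sums and the result surfaces on tgt.Done(). A minimal usage sketch:
+ //
+ //	jb.write(0, ref)           // write references into the job
+ //	tgt.Set(size, sections, 2) // final length known; closes doneC
+ //	root := <-tgt.Done()       // root reference emitted by the last sum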
correctRefHex := "0x8ace4673563b86281778b943aa60481fc4ede9f238dd98f1b3a5df4cb54ee79b" + refHex := hexutil.Encode(ref) + if refHex != correctRefHex { + t.Fatalf("job write full: expected %s, got %s", correctRefHex, refHex) + } + case <-ctx.Done(): + t.Fatalf("timeout: %v", ctx.Err()) + } + if jb.count() != branches { + t.Fatalf("jobcount: expected %d, got %d", branches, jb.count()) + } +} + +// TestJobWriteSpan uses the bmt asynchronous hasher +// it verifies that a result can be attained with chunkSize+sectionSize*2 bytes worth of references +// which translates to chunkSize*branches+chunkSize*2 bytes worth of data +func TestJobWriteSpan(t *testing.T) { + + tgt := newTarget() + hashFunc := testutillocal.NewBMTHasherFunc(0) + params := newTreeParams(hashFunc) + + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + _, data := testutil.SerialData(chunkSize+sectionSize*2, 255, 0) + + for i := 0; i < chunkSize; i += sectionSize { + jb.write(i/sectionSize, data[i:i+sectionSize]) + } + jbn := jb.Next() + jbn.write(0, data[chunkSize:chunkSize+sectionSize]) + jbn.write(1, data[chunkSize+sectionSize:]) + finalSize := chunkSize*branches + chunkSize*2 + finalSection := dataSizeToSectionIndex(finalSize, sectionSize) + tgt.Set(finalSize, finalSection, 3) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1000) + defer cancel() + select { + case ref := <-tgt.Done(): + // TODO: double check that this hash is correct!! + refCorrectHex := "0xee56134cab34a5a612648dcc22d88b7cb543081bd144906dfc4fa93802c9addf" + refHex := hexutil.Encode(ref) + if refHex != refCorrectHex { + t.Fatalf("writespan sequential: expected %s, got %s", refCorrectHex, refHex) + } + case <-ctx.Done(): + t.Fatalf("timeout: %v", ctx.Err()) + } + + sz := jb.size() + if sz != chunkSize*branches { + t.Fatalf("job 1 size: expected %d, got %d", chunkSize*branches, sz) + } + + sz = jbn.size() + if sz != chunkSize*2 { + t.Fatalf("job 2 size: expected %d, got %d", chunkSize*2, sz) + } +} + +// TestJobWriteSpanShuffle does the same as TestJobWriteSpan but +// shuffles the indices of the first chunk write +// verifying that sequential use of the underlying hasher is not required +func TestJobWriteSpanShuffle(t *testing.T) { + + tgt := newTarget() + hashFunc := testutillocal.NewBMTHasherFunc(0) + params := newTreeParams(hashFunc) + + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + _, data := testutil.SerialData(chunkSize+sectionSize*2, 255, 0) + + var idxs []int + for i := 0; i < branches; i++ { + idxs = append(idxs, i) + } + rand.Shuffle(branches, func(i int, j int) { + idxs[i], idxs[j] = idxs[j], idxs[i] + }) + for _, idx := range idxs { + jb.write(idx, data[idx*sectionSize:idx*sectionSize+sectionSize]) + } + + jbn := jb.Next() + jbn.write(0, data[chunkSize:chunkSize+sectionSize]) + jbn.write(1, data[chunkSize+sectionSize:]) + finalSize := chunkSize*branches + chunkSize*2 + finalSection := dataSizeToSectionIndex(finalSize, sectionSize) + tgt.Set(finalSize, finalSection, 3) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + select { + case ref := <-tgt.Done(): + refCorrectHex := "0xee56134cab34a5a612648dcc22d88b7cb543081bd144906dfc4fa93802c9addf" + refHex := hexutil.Encode(ref) + jbparent := jb.parent() + jbnparent := jbn.parent() + log.Info("succeeding", "jb count", jb.count(), "jbn count", jbn.count(), "jb parent count", jbparent.count(), "jbn parent count", jbnparent.count()) + if refHex != refCorrectHex { + t.Fatalf("writespan shuffle: expected %s, got %s", refCorrectHex, refHex) + } + case
<-ctx.Done(): + + jbparent := jb.parent() + jbnparent := jbn.parent() + log.Error("failing", "jb count", jb.count(), "jbn count", jbn.count(), "jb parent count", jbparent.count(), "jbn parent count", jbnparent.count(), "jb parent p", fmt.Sprintf("%p", jbparent), "jbn parent p", fmt.Sprintf("%p", jbnparent)) + t.Fatalf("timeout: %v", ctx.Err()) + } + + sz := jb.size() + if sz != chunkSize*branches { + t.Fatalf("job size: expected %d, got %d", chunkSize*branches, sz) + } + + sz = jbn.size() + if sz != chunkSize*2 { + t.Fatalf("job size: expected %d, got %d", chunkSize*2, sz) + } +} + +func TestJobWriteDoubleSection(t *testing.T) { + writeSize := sectionSize * 2 + dummyHashDoubleFunc := func(_ context.Context) param.SectionWriter { + return newDummySectionWriter(chunkSize, sectionSize*2, sectionSize*2, branches/2) + } + params := newTreeParams(dummyHashDoubleFunc) + + tgt := newTarget() + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + _, data := testutil.SerialData(chunkSize, 255, 0) + + for i := 0; i < chunkSize; i += writeSize { + jb.write(i/writeSize, data[i:i+writeSize]) + } + tgt.Set(chunkSize, branches/2-1, 2) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + select { + case refLong := <-tgt.Done(): + refLongHex := hexutil.Encode(refLong) + correctRefLongHex := "0x8ace4673563b86281778b943aa60481fc4ede9f238dd98f1b3a5df4cb54ee79b" + zeroHex + if refLongHex != correctRefLongHex { + t.Fatalf("section long: expected %s, got %s", correctRefLongHex, refLongHex) + } + case <-ctx.Done(): + t.Fatalf("timeout: %v", ctx.Err()) + } + +} + +// TestJobVector executes the barebones functionality of the hasher +// and verifies against source of truth results generated from the reference hasher +// for the same data +// TODO: vet dynamically against the referencefilehasher instead of expect vector +func TestJobVector(t *testing.T) { + poolSync := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + dataHash := bmt.New(poolSync) + hashFunc := testutillocal.NewBMTHasherFunc(0) + params := newTreeParams(hashFunc) + var mismatch int + + for i := start; i < end; i++ { + tgt := newTarget() + dataLength := dataLengths[i] + _, data := testutil.SerialData(dataLength, 255, 0) + jb := newJob(params, tgt, nil, 1, 0) + jb.start() + count := 0 + log.Info("test vector", "length", dataLength) + for i := 0; i < dataLength; i += chunkSize { + ie := i + chunkSize + if ie > dataLength { + ie = dataLength + } + writeSize := ie - i + dataHash.Reset() + dataHash.SetLength(writeSize) + c, err := dataHash.Write(data[i:ie]) + if err != nil { + jb.destroy() + t.Fatalf("data ref fail: %v", err) + } + if c != ie-i { + jb.destroy() + t.Fatalf("data ref short write: expect %d, got %d", ie-i, c) + } + ref := dataHash.Sum(nil) + log.Debug("data ref", "i", i, "ie", ie, "data", hexutil.Encode(ref)) + jb.write(count, ref) + count++ + if ie%(chunkSize*branches) == 0 { + jb = jb.Next() + count = 0 + } + } + dataSections := dataSizeToSectionIndex(dataLength, params.SectionSize) + tgt.Set(dataLength, dataSections, getLevelsFromLength(dataLength, params.SectionSize, params.Branches)) + eq := true + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1000) + defer cancel() + select { + case ref := <-tgt.Done(): + refCorrectHex := "0x" + expected[i] + refHex := hexutil.Encode(ref) + if refHex != refCorrectHex { + mismatch++ + eq = false + } + t.Logf("[%7d+%4d]\t%v\tref: %x\texpect: %s", dataLength/chunkSize, dataLength%chunkSize, eq, ref,
expected[i]) + case <-ctx.Done(): + t.Fatalf("timeout: %v", ctx.Err()) + } + + } + if mismatch > 0 { + t.Fatalf("mismatches: %d/%d", mismatch, end-start) + } +} + +// BenchmarkJob generates benchmarks that are comparable to the pyramid hasher +func BenchmarkJob(b *testing.B) { + for i := start; i < end; i++ { + b.Run(fmt.Sprintf("%d/%d", i, dataLengths[i]), benchmarkJob) + } +} + +func benchmarkJob(b *testing.B) { + params := strings.Split(b.Name(), "/") + dataLengthParam, err := strconv.ParseInt(params[2], 10, 64) + if err != nil { + b.Fatal(err) + } + dataLength := int(dataLengthParam) + + poolSync := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + dataHash := bmt.New(poolSync) + hashFunc := testutillocal.NewBMTHasherFunc(0) + treeParams := newTreeParams(hashFunc) + _, data := testutil.SerialData(dataLength, 255, 0) + + for j := 0; j < b.N; j++ { + tgt := newTarget() + jb := newJob(treeParams, tgt, nil, 1, 0) + jb.start() + count := 0 + for i := 0; i < dataLength; i += chunkSize { + ie := i + chunkSize + if ie > dataLength { + ie = dataLength + } + writeSize := ie - i + dataHash.Reset() + dataHash.SetLength(writeSize) + c, err := dataHash.Write(data[i:ie]) + if err != nil { + jb.destroy() + b.Fatalf("data ref fail: %v", err) + } + if c != ie-i { + jb.destroy() + b.Fatalf("data ref short write: expect %d, got %d", ie-i, c) + } + ref := dataHash.Sum(nil) + jb.write(count, ref) + count++ + if ie%(chunkSize*branches) == 0 { + jb = jb.Next() + count = 0 + } + } + dataSections := dataSizeToSectionIndex(dataLength, treeParams.SectionSize) + tgt.Set(dataLength, dataSections, getLevelsFromLength(dataLength, treeParams.SectionSize, treeParams.Branches)) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*1000) + defer cancel() + select { + case <-tgt.Done(): + case <-ctx.Done(): + b.Fatalf("timeout: %v", ctx.Err()) + } + } +} diff --git a/file/hasher/param.go b/file/hasher/param.go new file mode 100644 index 0000000000..33210ae835 --- /dev/null +++ b/file/hasher/param.go @@ -0,0 +1,59 @@ +package hasher + +import ( + "context" + "sync" + + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" +) + +// treeParams defines the boundaries of the hashing job and also contains the hash factory function of the job +// setting Debug means omitting any automatic behavior (for now it means job processing won't auto-start) +type treeParams struct { + SectionSize int + Branches int + ChunkSize int + Spans []int + Debug bool + hashFunc param.SectionWriterFunc + writerPool sync.Pool + ctx context.Context +} + +func newTreeParams(hashFunc param.SectionWriterFunc) *treeParams { + + h := hashFunc(context.Background()) + p := &treeParams{ + SectionSize: h.SectionSize(), + Branches: h.Branches(), + ChunkSize: h.SectionSize() * h.Branches(), + hashFunc: hashFunc, + } + h.Reset() + log.Trace("new tree params", "sectionsize", p.SectionSize, "branches", p.Branches, "chunksize", p.ChunkSize) + p.writerPool.New = func() interface{} { + return p.hashFunc(p.ctx) + } + p.Spans = generateSpanSizes(p.Branches, 9) + return p +} + +func (p *treeParams) SetContext(ctx context.Context) { + p.ctx = ctx +} + +func (p *treeParams) GetContext() context.Context { + return p.ctx +} + +func (p *treeParams) PutWriter(w param.SectionWriter) { + w.Reset() + p.writerPool.Put(w) +} + +func (p *treeParams) GetWriter() param.SectionWriter { + return
p.writerPool.Get().(param.SectionWriter) +} diff --git a/file/hasher/pyramid_test.go b/file/hasher/pyramid_test.go new file mode 100644 index 0000000000..a8f03c9ac0 --- /dev/null +++ b/file/hasher/pyramid_test.go @@ -0,0 +1,84 @@ +package hasher + +import ( + "bytes" + "context" + "fmt" + "io" + "strconv" + "strings" + "testing" + + "github.com/ethersphere/swarm/chunk" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/storage" + "github.com/ethersphere/swarm/testutil" +) + +// TestPyramidHasherVector executes the file hasher algorithms on serial input data of periods of 0-254 +// of lengths defined in common_test.go +func TestPyramidHasherVector(t *testing.T) { + t.Skip("only provided for easy reference to bug in case chunkSize*129") + var mismatch int + for i := start; i < end; i++ { + eq := true + dataLength := dataLengths[i] + log.Info("pyramidvector start", "i", i, "l", dataLength) + buf, _ := testutil.SerialData(dataLength, 255, 0) + putGetter := storage.NewHasherStore(&storage.FakeChunkStore{}, storage.MakeHashFunc(storage.BMTHash), false, chunk.NewTag(0, "foo", 0, false)) + + ctx := context.Background() + ref, wait, err := storage.PyramidSplit(ctx, buf, putGetter, putGetter, chunk.NewTag(0, "foo", int64(dataLength/4096+1), false)) + if err != nil { + t.Fatalf(err.Error()) + } + err = wait(ctx) + if err != nil { + t.Fatalf(err.Error()) + } + if ref.Hex() != expected[i] { + mismatch++ + eq = false + } + t.Logf("[%7d+%4d]\t%v\tref: %s\texpect: %s", dataLength/chunkSize, dataLength%chunkSize, eq, ref, expected[i]) + } + + if mismatch != 1 { + t.Fatalf("mismatches: %d/%d", mismatch, end-start) + } +} + +// BenchmarkPyramidHasher establishes the benchmark BenchmarkHasher should be compared to +func BenchmarkPyramidHasher(b *testing.B) { + + for i := start; i < end; i++ { + b.Run(fmt.Sprintf("%d", dataLengths[i]), benchmarkPyramidHasher) + } +} + +func benchmarkPyramidHasher(b *testing.B) { + params := strings.Split(b.Name(), "/") + dataLength, err := strconv.ParseInt(params[1], 10, 64) + if err != nil { + b.Fatal(err) + } + _, data := testutil.SerialData(int(dataLength), 255, 0) + buf := bytes.NewReader(data) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + buf.Seek(0, io.SeekStart) + //putGetter := newTestHasherStore(&storage.FakeChunkStore{}, storage.BMTHash) + putGetter := storage.NewHasherStore(&storage.FakeChunkStore{}, storage.MakeHashFunc(storage.BMTHash), false, chunk.NewTag(0, "foo", 0, false)) + + ctx := context.Background() + _, wait, err := storage.PyramidSplit(ctx, buf, putGetter, putGetter, chunk.NewTag(0, "foo", dataLength/4096+1, false)) + if err != nil { + b.Fatalf(err.Error()) + } + err = wait(ctx) + if err != nil { + b.Fatalf(err.Error()) + } + } +} diff --git a/file/hasher/reference.go b/file/hasher/reference.go new file mode 100644 index 0000000000..8c316ec618 --- /dev/null +++ b/file/hasher/reference.go @@ -0,0 +1,117 @@ +package hasher + +import ( + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" +) + +// ReferenceHasher is the source-of-truth implementation of the swarm file hashing algorithm +type ReferenceHasher struct { + params *treeParams + cursors []int // section write position, indexed per level + length int // number of bytes written to the data level of the hasher + buffer []byte // keeps data and hashes, indexed by cursors + counts []int // number of sums performed, indexed per level + hasher param.SectionWriter // underlying hasher +} + +// NewReferenceHasher constructs and returns a new ReferenceHasher 
+func NewReferenceHasher(params *treeParams) *ReferenceHasher { + // TODO: remove when bmt interface is amended + h := params.GetWriter() + return &ReferenceHasher{ + params: params, + cursors: make([]int, 9), + counts: make([]int, 9), + buffer: make([]byte, params.ChunkSize*9), + hasher: h, + } +} + +// Hash computes and returns the root hash of arbitrary data +func (r *ReferenceHasher) Hash(data []byte) []byte { + l := r.params.ChunkSize + for i := 0; i < len(data); i += r.params.ChunkSize { + if len(data)-i < r.params.ChunkSize { + l = len(data) - i + } + r.update(0, data[i:i+l]) + } + for i := 0; i < 9; i++ { + log.Trace("cursor", "lvl", i, "pos", r.cursors[i]) + } + return r.digest() +} + +// write to the data buffer on the specified level +// calls sum if chunk boundary is reached and recursively calls this function for the next level with the acquired bmt hash +// adjusts cursors accordingly +func (r *ReferenceHasher) update(lvl int, data []byte) { + if lvl == 0 { + r.length += len(data) + } + copy(r.buffer[r.cursors[lvl]:r.cursors[lvl]+len(data)], data) + r.cursors[lvl] += len(data) + if r.cursors[lvl]-r.cursors[lvl+1] == r.params.ChunkSize { + ref := r.sum(lvl) + r.update(lvl+1, ref) + r.cursors[lvl] = r.cursors[lvl+1] + } +} + +// calculates and returns the bmt sum of the last written data on the level +func (r *ReferenceHasher) sum(lvl int) []byte { + r.counts[lvl]++ + spanSize := r.params.Spans[lvl] * r.params.ChunkSize + span := (r.length-1)%spanSize + 1 + + toSumSize := r.cursors[lvl] - r.cursors[lvl+1] + + r.hasher.Reset() + r.hasher.SetSpan(span) + r.hasher.Write(r.buffer[r.cursors[lvl+1] : r.cursors[lvl+1]+toSumSize]) + ref := r.hasher.Sum(nil) + return ref +} + +// called after all data has been written +// sums the final chunks of each level +// skips intermediate levels that end on span boundary +func (r *ReferenceHasher) digest() []byte { + + // if we did not end on a chunk boundary, the last chunk hasn't been hashed + // we need to do this first + if r.length%r.params.ChunkSize != 0 { + ref := r.sum(0) + copy(r.buffer[r.cursors[1]:], ref) + r.cursors[1] += len(ref) + r.cursors[0] = r.cursors[1] + } + + // calculate the total number of levels needed to represent the data (including the data level) + targetLevel := getLevelsFromLength(r.length, r.params.SectionSize, r.params.Branches) + + // sum every intermediate level and write to the level above it + for i := 1; i < targetLevel; i++ { + + // if the tree is balanced or if there is a single reference outside a balanced tree on this level + // don't hash it again but pass it on to the next level + if r.counts[i] > 0 { + // TODO: simplify if possible + if r.counts[i-1]-r.params.Spans[targetLevel-1-i] <= 1 { + log.Trace("skip") + r.cursors[i+1] = r.cursors[i] + r.cursors[i] = r.cursors[i-1] + continue + } + } + + ref := r.sum(i) + copy(r.buffer[r.cursors[i+1]:], ref) + r.cursors[i+1] += len(ref) + r.cursors[i] = r.cursors[i+1] + } + + // the first section of the buffer will hold the root hash + return r.buffer[:r.params.SectionSize] +} diff --git a/file/hasher/reference_test.go b/file/hasher/reference_test.go new file mode 100644 index 0000000000..a72999874e --- /dev/null +++ b/file/hasher/reference_test.go @@ -0,0 +1,140 @@ +package hasher + +import ( + "context" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/testutil" + 
"golang.org/x/crypto/sha3" +) + +// TestManualDanglingChunk is a test script explicitly hashing and writing every individual level in the dangling chunk edge case +// we use a balanced tree with data size of chunkSize*branches, and a single chunk of data +// this case is chosen because it produces the wrong result in the pyramid hasher at the time of writing (master commit hash 4928d989ebd0854d993c10c194e61a5a5455e4f9) +func TestManualDanglingChunk(t *testing.T) { + pool := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + h := bmt.New(pool) + + // to execute the job we need buffers with the following capacities: + // level 0: chunkSize*branches+chunkSize + // level 1: chunkSize + // level 2: sectionSize * 2 + var levels [][]byte + levels = append(levels, nil) + levels = append(levels, make([]byte, chunkSize)) + levels = append(levels, make([]byte, sectionSize*2)) + + // hash the balanced tree portion of the data level and write to level 1 + _, levels[0] = testutil.SerialData(chunkSize*branches+chunkSize, 255, 0) + for i := 0; i < chunkSize*branches; i += chunkSize { + h.Reset() + h.SetSpan(chunkSize) + h.Write(levels[0][i : i+chunkSize]) + copy(levels[1][i/branches:], h.Sum(nil)) + } + refHex := hexutil.Encode(levels[1][:sectionSize]) + correctRefHex := "0xc10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if refHex != correctRefHex { + t.Fatalf("manual dangling single chunk; expected %s, got %s", correctRefHex, refHex) + } + + // write the dangling chunk + // hash it and write the reference on the second section of level 2 + h.Reset() + h.SetSpan(chunkSize) + h.Write(levels[0][chunkSize*branches:]) + copy(levels[2][sectionSize:], h.Sum(nil)) + refHex = hexutil.Encode(levels[2][sectionSize:]) + correctRefHex = "0x81b31d9a7f6c377523e8769db021091df23edd9fd7bd6bcdf11a22f518db6006" + if refHex != correctRefHex { + t.Fatalf("manual dangling single chunk; expected %s, got %s", correctRefHex, refHex) + } + + // hash the chunk on level 1 and write into the first section of level 2 + h.Reset() + h.SetSpan(chunkSize * branches) + h.Write(levels[1]) + copy(levels[2], h.Sum(nil)) + refHex = hexutil.Encode(levels[2][:sectionSize]) + correctRefHex = "0x3047d841077898c26bbe6be652a2ec590a5d9bd7cd45d290ea42511b48753c09" + if refHex != correctRefHex { + t.Fatalf("manual dangling balanced tree; expected %s, got %s", correctRefHex, refHex) + } + + // hash the two sections on level 2 to obtain the root hash + h.Reset() + h.SetSpan(chunkSize*branches + chunkSize) + h.Write(levels[2]) + ref := h.Sum(nil) + refHex = hexutil.Encode(ref) + correctRefHex = "0xb8e1804e37a064d28d161ab5f256cc482b1423d5cd0a6b30fde7b0f51ece9199" + if refHex != correctRefHex { + t.Fatalf("manual dangling root; expected %s, got %s", correctRefHex, refHex) + } +} + +// TestReferenceFileHasherVector executes the file hasher algorithms on serial input data of periods of 0-254 +// of lengths defined in common_test.go +// +// the "expected" array in common_test.go is generated by this implementation, and test failure due to +// result mismatch is nothing else than an indication that something has changed in the reference filehasher +// or the underlying hashing algorithm +func TestReferenceHasherVector(t *testing.T) { + + hashFunc := func(_ context.Context) param.SectionWriter { + pool := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + return bmt.New(pool) + } + params := newTreeParams(hashFunc) + var mismatch int + for i := start; i < end; i++ { + dataLength := dataLengths[i] + 
log.Info("start", "i", i, "len", dataLength) + rh := NewReferenceHasher(params) + _, data := testutil.SerialData(dataLength, 255, 0) + refHash := rh.Hash(data) + eq := true + if expected[i] != fmt.Sprintf("%x", refHash) { + mismatch++ + eq = false + } + t.Logf("[%7d+%4d]\t%v\tref: %x\texpect: %s", dataLength/chunkSize, dataLength%chunkSize, eq, refHash, expected[i]) + } + if mismatch > 0 { + t.Fatalf("mismatches: %d/%d", mismatch, end-start) + } +} + +// BenchmarkReferenceHasher establishes a baseline for a fully synchronous file hashing operation +// it will be vastly inefficient +func BenchmarkReferenceHasher(b *testing.B) { + for i := start; i < end; i++ { + b.Run(fmt.Sprintf("%d", dataLengths[i]), benchmarkReferenceHasher) + } +} + +func benchmarkReferenceHasher(b *testing.B) { + benchParams := strings.Split(b.Name(), "/") + dataLength, err := strconv.ParseInt(benchParams[1], 10, 64) + if err != nil { + b.Fatal(err) + } + hashFunc := func(_ context.Context) param.SectionWriter { + pool := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, bmt.PoolSize) + return bmt.New(pool) + } + params := newTreeParams(hashFunc) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, data := testutil.SerialData(int(dataLength), 255, 0) + fh := NewReferenceHasher(params) + fh.Hash(data) + } +} diff --git a/file/hasher/target.go b/file/hasher/target.go new file mode 100644 index 0000000000..9d566fa490 --- /dev/null +++ b/file/hasher/target.go @@ -0,0 +1,61 @@ +package hasher + +import ( + "sync" + + "github.com/ethersphere/swarm/log" +) + +// passed to a job to determine at which data lengths and levels a job should terminate +type target struct { + size int32 // bytes written + sections int32 // sections written + level int32 // target level calculated from bytes written against branching factor and sector size + resultC chan []byte // channel to receive root hash + doneC chan struct{} // when this channel is closed all jobs will calculate their end write count + mu sync.Mutex +} + +func newTarget() *target { + return &target{ + resultC: make(chan []byte), + doneC: make(chan struct{}), + } +} + +// Set is called when the final length of the data to be written is known +// TODO: method can be simplified to calculate sections and level internally +func (t *target) Set(size int, sections int, level int) { + t.mu.Lock() + defer t.mu.Unlock() + t.size = int32(size) + t.sections = int32(sections) + t.level = int32(level) + log.Trace("target set", "size", t.size, "section", t.sections, "level", t.level) + close(t.doneC) +} + +// Count returns the total section count for the target +// it should only be called after Set() +func (t *target) Count() int { + t.mu.Lock() + defer t.mu.Unlock() + return int(t.sections) + 1 +} + +func (t *target) Level() int { + t.mu.Lock() + defer t.mu.Unlock() + return int(t.level) +} + +func (t *target) Size() int { + t.mu.Lock() + defer t.mu.Unlock() + return int(t.size) +} + +// Done returns the channel in which the root hash will be sent +func (t *target) Done() <-chan []byte { + return t.resultC +} diff --git a/file/hasher/util.go b/file/hasher/util.go new file mode 100644 index 0000000000..8dd8b4a27f --- /dev/null +++ b/file/hasher/util.go @@ -0,0 +1,58 @@ +package hasher + +import ( + "math" +) + +// TODO: level 0 should be SectionSize() not Branches() +// generates a dictionary of maximum span lengths per level represented by one SectionSize() of data +func generateSpanSizes(branches int, levels int) []int { + spans := make([]int, levels) + span := 1 + for i := 0; i < 9; i++ { 
spans[i] = span + span *= branches + } + return spans +} + +// calculates the section index of the given byte size +func dataSizeToSectionIndex(length int, sectionSize int) int { + return (length - 1) / sectionSize +} + +// calculates the section count of the given byte size +func dataSizeToSectionCount(length int, sectionSize int) int { + return dataSizeToSectionIndex(length, sectionSize) + 1 +} + +// calculates the corresponding level section for a data section +func dataSectionToLevelSection(p *treeParams, lvl int, sections int) int { + span := p.Spans[lvl] + return sections / span +} + +// calculates the lower data section boundary of a level for which a data section is contained +// the higher level use is to determine whether the final data section written falls within +// a certain level's span +func dataSectionToLevelBoundary(p *treeParams, lvl int, section int) int { + span := p.Spans[lvl+1] + spans := section / span + spanBytes := spans * span + return spanBytes +} + +// TODO: use params instead of sectionSize, branches +// calculates the last level index which a particular data section count will result in. +// the returned level will be the level of the root hash +func getLevelsFromLength(l int, sectionSize int, branches int) int { + if l == 0 { + return 0 + } else if l <= sectionSize*branches { + return 1 + } + c := (l - 1) / sectionSize + + return int(math.Log(float64(c))/math.Log(float64(branches)) + 1) +} diff --git a/file/hasher/util_test.go b/file/hasher/util_test.go new file mode 100644 index 0000000000..f5678b9a38 --- /dev/null +++ b/file/hasher/util_test.go @@ -0,0 +1,91 @@ +package hasher + +import "testing" + +// TestLevelsFromLength verifies getLevelsFromLength +func TestLevelsFromLength(t *testing.T) { + + sizes := []int{sectionSize, chunkSize, chunkSize + sectionSize, chunkSize * branches, chunkSize*branches + 1} + expects := []int{1, 1, 2, 2, 3} + + for i, size := range sizes { + lvl := getLevelsFromLength(size, sectionSize, branches) + if expects[i] != lvl { + t.Fatalf("size %d, expected %d, got %d", size, expects[i], lvl) + } + } +} + +// TestDataSizeToSectionIndex verifies dataSizeToSectionIndex +func TestDataSizeToSectionIndex(t *testing.T) { + + sizes := []int{chunkSize - 1, chunkSize, chunkSize + 1} + expects := []int{branches - 1, branches - 1, branches} + + for j, size := range sizes { + r := dataSizeToSectionIndex(size, sectionSize) + expect := expects[j] + if expect != r { + t.Fatalf("size %d section %d: expected %d, got %d", size, sectionSize, expect, r) + } + } + +} + +// TestDataSectionToLevelSection verifies dataSectionToLevelSection +func TestDataSectionToLevelSection(t *testing.T) { + + params := newTreeParams(dummyHashFunc) + sections := []int{0, branches - 1, branches, branches + 1, branches * 2, branches*2 + 1, branches * branches} + levels := []int{1, 2} + expects := []int{ + 0, 0, 1, 1, 2, 2, 128, + 0, 0, 0, 0, 0, 0, 1, + } + + for i, lvl := range levels { + for j, section := range sections { + r := dataSectionToLevelSection(params, lvl, section) + k := i*len(sections) + j + expect := expects[k] + if expect != r { + t.Fatalf("levelsection size %d level %d: expected %d, got %d", section, lvl, expect, r) + } + } + } + +} + +// TestDataSectionToLevelBoundary verifies dataSectionToLevelBoundary +func TestDataSectionToLevelBoundary(t *testing.T) { + params := newTreeParams(dummyHashFunc) + size := chunkSize*branches + chunkSize*2 + section :=
dataSizeToSectionIndex(size, sectionSize) + lvl := 1 + expect := branches * branches + + r := dataSectionToLevelBoundary(params, lvl, section) + if expect != r { + t.Fatalf("levelboundary size %d level %d: expected %d, got %d", section, lvl, expect, r) + } + + size = chunkSize*branches*branches + chunkSize*2 + section = dataSizeToSectionIndex(size, sectionSize) + lvl = 1 + expect = branches * branches * branches + + r = dataSectionToLevelBoundary(params, lvl, section) + if expect != r { + t.Fatalf("levelboundary size %d level %d: expected %d, got %d", section, lvl, expect, r) + } + + size = chunkSize*branches + chunkSize*2 + section = dataSizeToSectionIndex(size, sectionSize) + lvl = 2 + expect = 0 + + r = dataSectionToLevelBoundary(params, lvl, section) + if expect != r { + t.Fatalf("levelboundary size %d level %d: expected %d, got %d", section, lvl, expect, r) + } +} diff --git a/file/split.go b/file/split.go new file mode 100644 index 0000000000..fdcf213e57 --- /dev/null +++ b/file/split.go @@ -0,0 +1,47 @@ +package file + +import ( + "io" + + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" +) + +// TODO: grow buffer on demand to reduce allocs +// Splitter reads a data stream section by section and feeds it to a param.SectionWriter +type Splitter struct { + r io.Reader + w param.SectionWriter +} + +// NewSplitter creates a new Splitter object +func NewSplitter(r io.Reader, w param.SectionWriter) *Splitter { + s := &Splitter{ + r: r, + w: w, + } + return s +} + +// Split is a blocking call that consumes and passes data from its reader to its SectionWriter +// according to the SectionWriter's SectionSize +// On EOF from the reader it calls Sum on the param.SectionWriter and returns the result +func (s *Splitter) Split() ([]byte, error) { + wc := 0 + l := 0 + for { + d := make([]byte, s.w.SectionSize()) + c, err := s.r.Read(d) + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + log.Trace("split read", "c", c, "wc", wc, "l", l) + s.w.Write(d[:c]) + wc++ + l += c + } + return s.w.Sum(nil), nil +} diff --git a/file/split_test.go b/file/split_test.go new file mode 100644 index 0000000000..ab1de3ba08 --- /dev/null +++ b/file/split_test.go @@ -0,0 +1,85 @@ +package file + +import ( + "context" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/file/hasher" + "github.com/ethersphere/swarm/file/store" + "github.com/ethersphere/swarm/file/testutillocal" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" + "github.com/ethersphere/swarm/storage" + "github.com/ethersphere/swarm/testutil" +) + +const ( + sectionSize = 32 + branches = 128 + chunkSize = 4096 +) + +func init() { + testutil.Init() +} + +var ( + errFunc = func(err error) { + log.Error("split writer pipeline error", "err", err) + } +) + +// TestSplit creates a Splitter with a reader with one chunk of serial data and +// a Hasher as the underlying param.SectionWriter +// It verifies the returned result +func TestSplit(t *testing.T) { + + hashFunc := testutillocal.NewBMTHasherFunc(0) + h := hasher.New(hashFunc) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.Init(ctx, errFunc) + + r, _ := testutil.SerialData(chunkSize, 255, 0) + s := NewSplitter(r, h) + ref, err := s.Split() + if err != nil { + t.Fatal(err) + } + refHex := hexutil.Encode(ref) + correctRefHex := "0xc10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if refHex != correctRefHex { + t.Fatalf("split, expected %s, got %s",
correctRefHex, refHex) + } +} + +// TestSplitWithDataFileStore verifies chunk.Store sink result for data hashing +func TestSplitWithDataFileStore(t *testing.T) { + hashFunc := testutillocal.NewBMTHasherFunc(128) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + chunkStore := &storage.FakeChunkStore{} + storeFunc := func(_ context.Context) param.SectionWriter { + h := store.New(chunkStore, hashFunc) + h.Init(ctx, errFunc) + return h + } + + h := hasher.New(storeFunc) + h.Init(ctx, errFunc) + + r, _ := testutil.SerialData(chunkSize, 255, 0) + s := NewSplitter(r, h) + ref, err := s.Split() + if err != nil { + t.Fatal(err) + } + refHex := hexutil.Encode(ref) + correctRefHex := "0xc10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if refHex != correctRefHex { + t.Fatalf("split, expected %s, got %s", correctRefHex, refHex) + } +} diff --git a/file/store/store.go b/file/store/store.go new file mode 100644 index 0000000000..8d238d71ef --- /dev/null +++ b/file/store/store.go @@ -0,0 +1,109 @@ +package store + +import ( + "context" + + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/chunk" + "github.com/ethersphere/swarm/log" + "github.com/ethersphere/swarm/param" +) + +// FileStore implements param.SectionWriter +// It intercepts data between source and hasher +// and compiles the data with the received hash on sum +// to a chunk to be passed to underlying chunk.Store.Put +type FileStore struct { + chunkStore chunk.Store + w param.SectionWriter + ctx context.Context + data [][]byte + span int + errFunc func(error) +} + +// New creates a new FileStore with the supplied chunk.Store +func New(chunkStore chunk.Store, writerFunc param.SectionWriterFunc) *FileStore { + f := &FileStore{ + chunkStore: chunkStore, + } + f.w = writerFunc(f.ctx) + return f +} + +func (f *FileStore) SetWriter(hashFunc param.SectionWriterFunc) param.SectionWriter { + f.w = hashFunc(f.ctx) + return f +} + +// Init implements param.SectionWriter +func (f *FileStore) Init(ctx context.Context, errFunc func(error)) { + f.ctx = ctx + f.errFunc = errFunc +} + +// Reset implements param.SectionWriter +func (f *FileStore) Reset() { + f.span = 0 + f.data = [][]byte{} + f.w.Reset() +} + +func (f *FileStore) SeekSection(index int) { + f.w.SeekSection(index) +} + +// Write implements param.SectionWriter +// it asynchronously writes to the underlying writer while caching the data slice +func (f *FileStore) Write(b []byte) (int, error) { + f.data = append(f.data, b) + return f.w.Write(b) +} + +// Sum implements param.SectionWriter +// calls underlying writer's Sum and sends the result with data as a chunk to chunk.Store +func (f *FileStore) Sum(b []byte) []byte { + ref := f.w.Sum(b) + go func(ref []byte) { + b = bmt.LengthToSpan(f.span) + for _, data := range f.data { + b = append(b, data...) 
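+ // Illustrative aside (not part of the original change): the chunk body being
+ // assembled is the 8-byte little-endian span followed by the cached payload;
+ // for a full 4096-byte chunk the span prefix would be 00 10 00 00 00 00 00 00:
+ //
+ //	body := append(bmt.LengthToSpan(4096), payload...) // len == 8 + 4096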
+ } + ch := chunk.NewChunk(ref, b) + _, err := f.chunkStore.Put(f.ctx, chunk.ModePutUpload, ch) + log.Trace("filestore put chunk", "ch", ch) + if err != nil { + f.errFunc(err) + } + }(ref) + return ref +} + +// SetSpan implements param.SectionWriter +func (f *FileStore) SetSpan(length int) { + f.span = length + f.w.SetSpan(length) +} + +// SetLength implements param.SectionWriter +func (f *FileStore) SetLength(length int) { + f.w.SetLength(length) +} + +// BlockSize implements param.SectionWriter +func (f *FileStore) BlockSize() int { + return f.w.BlockSize() +} + +// SectionSize implements param.SectionWriter +func (f *FileStore) SectionSize() int { + return f.w.SectionSize() +} + +// Size implements param.SectionWriter +func (f *FileStore) Size() int { + return f.w.Size() +} + +// Branches implements param.SectionWriter +func (f *FileStore) Branches() int { + return f.w.Branches() +} diff --git a/file/store/store_test.go b/file/store/store_test.go new file mode 100644 index 0000000000..4c3620df23 --- /dev/null +++ b/file/store/store_test.go @@ -0,0 +1,94 @@ +package store + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/chunk" + "github.com/ethersphere/swarm/file/testutillocal" + "github.com/ethersphere/swarm/storage" + "github.com/ethersphere/swarm/testutil" +) + +const ( + sectionSize = 32 + branches = 128 + chunkSize = 4096 +) + +func init() { + testutil.Init() +} + +// testChunkStore wraps storage.FakeChunkStore to intercept incoming chunks +type testChunkStore struct { + *storage.FakeChunkStore + chunkC chan<- chunk.Chunk +} + +func newTestChunkStore(chunkC chan<- chunk.Chunk) *testChunkStore { + return &testChunkStore{ + FakeChunkStore: &storage.FakeChunkStore{}, + chunkC: chunkC, + } +} + +// Put overrides storage.FakeChunkStore.Put +func (s *testChunkStore) Put(_ context.Context, _ chunk.ModePut, chs ...chunk.Chunk) ([]bool, error) { + for _, ch := range chs { + s.chunkC <- ch + } + return s.FakeChunkStore.Put(nil, 0, chs...)
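+ // Illustrative aside (not part of the original change): this wrapper pattern
+ // forwards every stored chunk to a test channel before delegating to the
+ // embedded FakeChunkStore, so a test can assert on the exact chunk produced:
+ //
+ //	chunkC := make(chan chunk.Chunk)
+ //	store := newTestChunkStore(chunkC)
+ //	// ... write through FileStore ...
+ //	ch := <-chunkC // the chunk that FileStore.Sum put to the store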
+} + +// TestStoreWithHasher writes a single chunk and verifies the asynchronously received chunk +// through the underlying chunk store +func TestStoreWithHasher(t *testing.T) { + + hashFunc := testutillocal.NewBMTHasherFunc(128) + + // initialize chunk store with channel to intercept chunk + chunkC := make(chan chunk.Chunk) + store := newTestChunkStore(chunkC) + + // initialize FileStore + h := New(store, hashFunc) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + h.Init(ctx, nil) + + // Write data to Store + _, data := testutil.SerialData(chunkSize, 255, 0) + span := bmt.LengthToSpan(chunkSize) + go func() { + for i := 0; i < chunkSize; i += sectionSize { + h.SeekSection(i / sectionSize) + h.Write(data[i : i+sectionSize]) + } + h.SetSpan(chunkSize) + h.SetLength(chunkSize) + h.Sum(nil) + }() + + // capture chunk and verify contents + select { + case ch := <-chunkC: + if !bytes.Equal(ch.Data()[:8], span) { + t.Fatalf("chunk span; expected %x, got %x", span, ch.Data()[:8]) + } + if !bytes.Equal(ch.Data()[8:], data) { + t.Fatalf("chunk data; expected %x, got %x", data, ch.Data()[8:]) + } + refHex := ch.Address().Hex() + correctRefHex := "c10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if refHex != correctRefHex { + t.Fatalf("chunk ref; expected %s, got %s", correctRefHex, refHex) + } + + case <-ctx.Done(): + t.Fatalf("timeout %v", ctx.Err()) + } +} diff --git a/file/testutillocal/cache.go b/file/testutillocal/cache.go new file mode 100644 index 0000000000..7b2968139b --- /dev/null +++ b/file/testutillocal/cache.go @@ -0,0 +1,106 @@ +package testutillocal + +import ( + "context" + + "github.com/ethersphere/swarm/param" +) + +var ( + defaultSectionSize = 32 + defaultBranches = 128 +) + +type Cache struct { + data map[int][]byte + index int + w param.SectionWriter +} + +func NewCache() *Cache { + return &Cache{ + data: make(map[int][]byte), + } +} + +func (c *Cache) Init(_ context.Context, _ func(error)) { +} + +func (c *Cache) SetWriter(writeFunc param.SectionWriterFunc) param.SectionWriter { + c.w = writeFunc(nil) + return c +} + +func (c *Cache) SetSpan(length int) { + if c.w != nil { + c.w.SetSpan(length) + } +} + +func (c *Cache) SetLength(length int) { + if c.w != nil { + c.w.SetLength(length) + } +} + +func (c *Cache) SeekSection(offset int) { + c.index = offset + if c.w != nil { + c.w.SeekSection(offset) + } +} + +func (c *Cache) Write(b []byte) (int, error) { + c.data[c.index] = b + if c.w != nil { + return c.w.Write(b) + } + return len(b), nil +} + +func (c *Cache) Sum(b []byte) []byte { + if c.w == nil { + return nil + } + return c.w.Sum(b) +} + +func (c *Cache) Reset() { + if c.w == nil { + return + } + c.w.Reset() +} + +func (c *Cache) SectionSize() int { + if c.w != nil { + return c.w.SectionSize() + } + return defaultSectionSize +} + +func (c *Cache) BlockSize() int { + return c.SectionSize() +} + +func (c *Cache) Size() int { + if c.w != nil { + return c.w.Size() + } + return defaultSectionSize +} + +func (c *Cache) Branches() int { + if c.w != nil { + return c.w.Branches() + } + return defaultBranches +} + +func (c *Cache) Get(index int) []byte { + return c.data[index] +} + +func (c *Cache) Delete(index int) { + delete(c.data, index) +} diff --git a/file/testutillocal/cache_test.go b/file/testutillocal/cache_test.go new file mode 100644 index 0000000000..cf43f0d3cb --- /dev/null +++ b/file/testutillocal/cache_test.go @@ -0,0 +1,53 @@ +package testutillocal + +import ( + "bytes" + "context" + "testing" + +
"github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethersphere/swarm/testutil" +) + +const ( + sectionSize = 32 + chunkSize = 4096 +) + +func init() { + testutil.Init() +} + +func TestCache(t *testing.T) { + c := NewCache() + c.Init(context.Background(), func(error) {}) + _, data := testutil.SerialData(chunkSize, 255, 0) + c.Write(data) + cachedData := c.Get(0) + if !bytes.Equal(cachedData, data) { + t.Fatalf("cache data; expected %x, got %x", data, cachedData) + } +} + +func TestCacheLink(t *testing.T) { + + hashFunc := NewBMTHasherFunc(0) + + c := NewCache() + c.Init(context.Background(), func(error) {}) + c.SetWriter(hashFunc) + _, data := testutil.SerialData(chunkSize, 255, 0) + c.SeekSection(-1) + c.Write(data) + ref := c.Sum(nil) + refHex := hexutil.Encode(ref) + correctRefHex := "0xc10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef" + if refHex != correctRefHex { + t.Fatalf("cache link; expected %s, got %s", correctRefHex, refHex) + } + + c.Delete(0) + if _, ok := c.data[0]; ok { + t.Fatalf("delete; expected not found") + } +} diff --git a/file/testutillocal/hash.go b/file/testutillocal/hash.go new file mode 100644 index 0000000000..d30e513f37 --- /dev/null +++ b/file/testutillocal/hash.go @@ -0,0 +1,24 @@ +package testutillocal + +import ( + "context" + + "github.com/ethersphere/swarm/bmt" + "github.com/ethersphere/swarm/param" + "golang.org/x/crypto/sha3" +) + +var ( + branches = 128 +) + +func NewBMTHasherFunc(poolSize int) param.SectionWriterFunc { + if poolSize == 0 { + poolSize = bmt.PoolSize + } + poolAsync := bmt.NewTreePool(sha3.NewLegacyKeccak256, branches, poolSize) + refHashFunc := func(_ context.Context) param.SectionWriter { + return bmt.New(poolAsync).NewAsyncWriter(false) + } + return refHashFunc +} diff --git a/param/hash.go b/param/hash.go new file mode 100644 index 0000000000..114fa0ad6c --- /dev/null +++ b/param/hash.go @@ -0,0 +1,7 @@ +package param + +import "golang.org/x/crypto/sha3" + +var ( + HashFunc = sha3.NewLegacyKeccak256 +) diff --git a/param/io.go b/param/io.go new file mode 100644 index 0000000000..3eb14bce67 --- /dev/null +++ b/param/io.go @@ -0,0 +1,19 @@ +package param + +import ( + "context" + "hash" +) + +type SectionWriterFunc func(ctx context.Context) SectionWriter + +type SectionWriter interface { + hash.Hash + Init(ctx context.Context, errFunc func(error)) // errFunc is used for asynchronous components to signal error and termination + SetWriter(hashFunc SectionWriterFunc) SectionWriter // chain another SectionWriter the current instance + SeekSection(section int) // sets cursor that next Write() will write to + SetLength(length int) // set total number of bytes that will be written to SectionWriter + SetSpan(length int) // set data span of chunk + SectionSize() int // section size of this SectionWriter + Branches() int // branch factor of this SectionWriter +} diff --git a/storage/chunker_test.go b/storage/chunker_test.go index fd1af937f2..3e1158d13f 100644 --- a/storage/chunker_test.go +++ b/storage/chunker_test.go @@ -151,7 +151,8 @@ func TestSha3ForCorrectness(t *testing.T) { rawSha3Output := rawSha3.Sum(nil) sha3FromMakeFunc := MakeHashFunc(SHA3Hash)() - sha3FromMakeFunc.ResetWithLength(input[:8]) + sha3FromMakeFunc.Reset() + sha3FromMakeFunc.SetSpanBytes(input[:8]) sha3FromMakeFunc.Write(input[8:]) sha3FromMakeFuncOutput := sha3FromMakeFunc.Sum(nil) diff --git a/storage/common_test.go b/storage/common_test.go index a65a686943..e625cd8091 100644 --- a/storage/common_test.go +++ b/storage/common_test.go @@ 
-151,7 +151,8 @@ func testStoreCorrect(m ChunkStore, n int, t *testing.T) { } hasher := MakeHashFunc(DefaultHash)() data := chunk.Data() - hasher.ResetWithLength(data[:8]) + hasher.Reset() + hasher.SetSpanBytes(data[:8]) hasher.Write(data[8:]) exp := hasher.Sum(nil) if !bytes.Equal(h, exp) { diff --git a/storage/encryption/encryption.go b/storage/encryption/encryption.go index a5ec2d5efa..6fbdab062b 100644 --- a/storage/encryption/encryption.go +++ b/storage/encryption/encryption.go @@ -31,14 +31,12 @@ type Key []byte type Encryption interface { Encrypt(data []byte) ([]byte, error) Decrypt(data []byte) ([]byte, error) - Reset() } type encryption struct { key Key // the encryption key (hashSize bytes long) keyLen int // length of the key = length of blockcipher block padding int // encryption will pad the data upto this if > 0 - index int // counter index initCtr uint32 // initial counter used for counter mode blockcipher hashFunc func() hash.Hash // hasher constructor function } @@ -81,24 +79,18 @@ func (e *encryption) Decrypt(data []byte) ([]byte, error) { return out, nil } -// Reset resets the counter. It is only safe to call after an encryption operation is completed -// After Reset is called, the Encryption object can be re-used for other data -func (e *encryption) Reset() { - e.index = 0 -} - -// split up input into keylength segments and encrypt sequentially +// split up input into keylength segments and encrypt concurrently func (e *encryption) transform(in, out []byte) { inLength := len(in) wg := sync.WaitGroup{} wg.Add((inLength-1)/e.keyLen + 1) for i := 0; i < inLength; i += e.keyLen { l := min(e.keyLen, inLength-i) + // call transformations per segment (asynchronously) go func(i int, x, y []byte) { defer wg.Done() e.Transcrypt(i, x, y) - }(e.index, in[i:i+l], out[i:i+l]) - e.index++ + }(i/e.keyLen, in[i:i+l], out[i:i+l]) } // pad the rest if out is longer pad(out[inLength:]) diff --git a/storage/encryption/encryption_test.go b/storage/encryption/encryption_test.go index 80ae3da4ef..c89ab184df 100644 --- a/storage/encryption/encryption_test.go +++ b/storage/encryption/encryption_test.go @@ -18,7 +18,6 @@ package encryption import ( "bytes" - crand "crypto/rand" "testing" "github.com/ethereum/go-ethereum/common" @@ -38,7 +37,6 @@ func init() { if err != nil { panic(err.Error()) } - testutil.Init() } func TestEncryptDataLongerThanPadding(t *testing.T) { @@ -134,7 +132,6 @@ func testEncryptDecryptIsIdentity(t *testing.T, padding int, initCtr uint32, dat t.Fatalf("Expected no error got %v", err) } - enc.Reset() decrypted, err := enc.Decrypt(encrypted) if err != nil { t.Fatalf("Expected no error got %v", err) @@ -152,42 +149,3 @@ func testEncryptDecryptIsIdentity(t *testing.T, padding int, initCtr uint32, dat t.Fatalf("Expected decrypted %v got %v", common.Bytes2Hex(data), common.Bytes2Hex(decrypted)) } } - -// TestEncryptSectioned tests that the cipherText is the same regardless of size of data input buffer -func TestEncryptSectioned(t *testing.T) { - data := make([]byte, 4096) - c, err := crand.Read(data) - if err != nil { - t.Fatal(err) - } - if c < 4096 { - t.Fatalf("short read %d", c) - } - - key := make([]byte, KeyLength) - c, err = crand.Read(key) - if err != nil { - t.Fatal(err) - } - if c < KeyLength { - t.Fatalf("short read %d", c) - } - - enc := New(key, 0, uint32(42), sha3.NewLegacyKeccak256) - whole, err := enc.Encrypt(data) - if err != nil { - t.Fatal(err) - } - - enc.Reset() - for i := 0; i < 4096; i += KeyLength { - cipher, err := enc.Encrypt(data[i : i+KeyLength]) - if err != nil { - t.Fatal(err) - } - wholeSection := whole[i :
i+KeyLength] - if !bytes.Equal(cipher, wholeSection) { - t.Fatalf("index %d, expected %x, got %x", i/KeyLength, wholeSection, cipher) - } - } -} diff --git a/storage/hasherstore.go b/storage/hasherstore.go index 4890219a15..d81ffba5aa 100644 --- a/storage/hasherstore.go +++ b/storage/hasherstore.go @@ -184,8 +184,9 @@ func (h *hasherStore) startWait(ctx context.Context) { func (h *hasherStore) createHash(chunkData ChunkData) Address { hasher := h.hashFunc() - hasher.ResetWithLength(chunkData[:8]) // 8 bytes of length - hasher.Write(chunkData[8:]) // minus 8 []byte length + hasher.Reset() + hasher.SetSpanBytes(chunkData[:8]) // 8 bytes of length + hasher.Write(chunkData[8:]) // minus 8 []byte length return hasher.Sum(nil) } diff --git a/storage/swarmhasher.go b/storage/swarmhasher.go index fae03f0c72..0cbc12556c 100644 --- a/storage/swarmhasher.go +++ b/storage/swarmhasher.go @@ -28,14 +28,14 @@ const ( type SwarmHash interface { hash.Hash - ResetWithLength([]byte) + SetSpanBytes([]byte) } type HashWithLength struct { hash.Hash } -func (h *HashWithLength) ResetWithLength(length []byte) { +func (h *HashWithLength) SetSpanBytes(length []byte) { h.Reset() h.Write(length) } diff --git a/storage/types.go b/storage/types.go index a4b102a62c..9fa258495d 100644 --- a/storage/types.go +++ b/storage/types.go @@ -93,7 +93,8 @@ func GenerateRandomChunk(dataSize int64) Chunk { sdata := make([]byte, dataSize+8) rand.Read(sdata[8:]) binary.LittleEndian.PutUint64(sdata[:8], uint64(dataSize)) - hasher.ResetWithLength(sdata[:8]) + hasher.Reset() + hasher.SetSpanBytes(sdata[:8]) hasher.Write(sdata[8:]) return NewChunk(hasher.Sum(nil), sdata) } @@ -202,7 +203,8 @@ func (v *ContentAddressValidator) Validate(ch Chunk) bool { } hasher := v.Hasher() - hasher.ResetWithLength(data[:8]) + hasher.Reset() + hasher.SetSpanBytes(data[:8]) hasher.Write(data[8:]) hash := hasher.Sum(nil) diff --git a/testutil/data.go b/testutil/data.go new file mode 100644 index 0000000000..f3bea59e91 --- /dev/null +++ b/testutil/data.go @@ -0,0 +1,15 @@ +package testutil + +import ( + "bytes" + "io" +) + +func SerialData(l int, mod int, offset int) (r io.Reader, slice []byte) { + slice = make([]byte, l) + for i := 0; i < len(slice); i++ { + slice[i] = byte((i + offset) % mod) + } + r = io.LimitReader(bytes.NewReader(slice), int64(l)) + return +}
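+ // Illustrative aside (not part of the original change): SerialData returns both
+ // a reader and the backing slice, so tests can stream the data and still assert
+ // on the raw bytes. A minimal usage sketch:
+ //
+ //	r, data := testutil.SerialData(4096, 255, 0)
+ //	buf := make([]byte, 32)
+ //	io.ReadFull(r, buf)          // reads data[:32]
+ //	bytes.Equal(buf, data[:32])  // true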