From a638f1b95d69b710930c94533103f17e3a2d7b74 Mon Sep 17 00:00:00 2001 From: Ralph Pichler Date: Wed, 5 Feb 2020 15:21:55 +0100 Subject: [PATCH] state, swap: add batch support to state store (#2088) --- state/dbstore.go | 33 +++++++++++++++++++++ state/dbstore_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++ swap/swap.go | 17 ++++++++--- 3 files changed, 115 insertions(+), 4 deletions(-) diff --git a/state/dbstore.go b/state/dbstore.go index 4d651e583b..5fa978675f 100644 --- a/state/dbstore.go +++ b/state/dbstore.go @@ -36,6 +36,7 @@ type Store interface { Put(key string, i interface{}) (err error) Delete(key string) (err error) Iterate(prefix string, iterFunc iterFunction) (err error) + WriteBatch(batch *StoreBatch) (err error) Close() error } @@ -132,3 +133,35 @@ func (s *DBStore) Iterate(prefix string, iterFunc iterFunction) (err error) { func (s *DBStore) Close() error { return s.db.Close() } + +// StoreBatch is a wrapper around a leveldb batch that takes care of the proper encoding. +type StoreBatch struct { + leveldb.Batch +} + +// Put encodes the value and puts a corresponding Put operation into the underlying batch. +// This only returns an error if the encoding failed. +func (b *StoreBatch) Put(key string, i interface{}) (err error) { + var bytes []byte + if marshaler, ok := i.(encoding.BinaryMarshaler); ok { + if bytes, err = marshaler.MarshalBinary(); err != nil { + return err + } + } else { + if bytes, err = json.Marshal(i); err != nil { + return err + } + } + b.Batch.Put([]byte(key), bytes) + return nil +} + +// Delete adds a delete operation to the underlying batch. +func (b *StoreBatch) Delete(key string) { + b.Batch.Delete([]byte(key)) +} + +// WriteBatch executes the batch on the underlying database. +func (s *DBStore) WriteBatch(batch *StoreBatch) error { + return s.db.Write(&batch.Batch, nil) +} diff --git a/state/dbstore_test.go b/state/dbstore_test.go index 1eddbe560e..27406c9371 100644 --- a/state/dbstore_test.go +++ b/state/dbstore_test.go @@ -86,6 +86,13 @@ func TestDBStore(t *testing.T) { } testStoreIterator(t, iteratedStore) + + batchedStore, err := NewDBStore(dir) + if err != nil { + t.Fatal(err) + } + + testStoreBatch(t, batchedStore) } func testStore(t *testing.T, store Store) { @@ -175,3 +182,65 @@ func testStoreIterator(t *testing.T, store Store) { t.Fatalf("expected store entries to be %v, are %v instead", expectedEntries, entries) } } + +func testStoreBatch(t *testing.T, store Store) { + defer store.Close() + + batch := new(StoreBatch) + + var val1 uint64 = 1 + var val2 uint64 = 2 + + err := batch.Put("key1", val1) + if err != nil { + t.Fatal(err) + } + + err = batch.Put("key2", val2) + if err != nil { + t.Fatal(err) + } + + err = store.WriteBatch(batch) + if err != nil { + t.Fatal(err) + } + + var result uint64 + err = store.Get("key1", &result) + if err != nil { + t.Fatal(err) + } + + if result != val1 { + t.Fatalf("expected key1 to be %d, was %d instead", val1, result) + } + + err = store.Get("key2", &result) + if err != nil { + t.Fatal(err) + } + + if result != val2 { + t.Fatalf("expected key1 to be %d, was %d instead", val2, result) + } + + batch = new(StoreBatch) + batch.Delete("key1") + batch.Delete("key2") + + err = store.WriteBatch(batch) + if err != nil { + t.Fatal(err) + } + + err = store.Get("key1", &result) + if err != ErrNotFound { + t.Fatal("expected key1 to be deleted") + } + + err = store.Get("key2", &result) + if err != ErrNotFound { + t.Fatal("expected key2 to be deleted") + } +} diff --git a/swap/swap.go b/swap/swap.go index 6218095358..bd2d5c8dd9 100644 --- a/swap/swap.go +++ b/swap/swap.go @@ -465,16 +465,25 @@ func (s *Swap) handleConfirmChequeMsg(ctx context.Context, p *Peer, msg *Confirm return fmt.Errorf("ignoring confirm msg, unexpected cheque, confirm message cheque %s, expected %s", cheque, p.getPendingCheque()) } - err := p.setLastSentCheque(cheque) + batch := new(state.StoreBatch) + err := batch.Put(sentChequeKey(p.ID()), cheque) if err != nil { - return protocols.Break(fmt.Errorf("setLastSentCheque failed: %w", err)) + return protocols.Break(fmt.Errorf("encoding cheque failed: %w", err)) } - err = p.setPendingCheque(nil) + err = batch.Put(pendingChequeKey(p.ID()), nil) if err != nil { - return protocols.Break(fmt.Errorf("setPendingCheque failed: %w", err)) + return protocols.Break(fmt.Errorf("encoding pending cheque failed: %w", err)) } + err = s.store.WriteBatch(batch) + if err != nil { + return protocols.Break(fmt.Errorf("could not write cheque to database: %w", err)) + } + + p.lastSentCheque = cheque + p.pendingCheque = nil + return nil }