diff --git a/bptree.go b/bptree.go index 8ed561b4..596dba85 100644 --- a/bptree.go +++ b/bptree.go @@ -1,6 +1,7 @@ package lotusdb import ( + "bytes" "context" "fmt" "path/filepath" @@ -198,3 +199,70 @@ func (bt *BPTree) Sync() error { } return nil } + +// bptreeIterator implement IteratorI +type bptreeIterator struct { + k []byte + v []byte + tx *bbolt.Tx + cursor *bbolt.Cursor + options IteratorOptions +} + +// NewBptreeIterator +func NewBptreeIterator(tx *bbolt.Tx, options IteratorOptions) (*bptreeIterator, error) { + b := tx.Bucket(indexBucketName) + c := b.Cursor() + return &bptreeIterator{ + cursor: c, + options: options, + tx: tx, + }, nil +} + +// Rewind seek the first key in the iterator. +func (bi *bptreeIterator) Rewind() { + if bi.options.Reverse { + bi.k, bi.v = bi.cursor.Last() + } else { + bi.k, bi.v = bi.cursor.First() + } +} + +// Seek move the iterator to the key which is +// greater(less when reverse is true) than or equal to the specified key. +func (bi *bptreeIterator) Seek(key []byte) { + bi.k, bi.v = bi.cursor.Seek(key) + if !bytes.Equal(bi.k, key) && bi.options.Reverse { + bi.k, bi.v = bi.cursor.Prev() + } +} + +// Next moves the iterator to the next key. +func (bi *bptreeIterator) Next() { + if bi.options.Reverse { + bi.k, bi.v = bi.cursor.Prev() + } else { + bi.k, bi.v = bi.cursor.Next() + } +} + +// Key get the current key. +func (bi *bptreeIterator) Key() []byte { + return bi.k +} + +// Value get the current value. +func (ci *bptreeIterator) Value() any { + return ci.v +} + +// Valid returns whether the iterator is exhausted. +func (ci *bptreeIterator) Valid() bool { + return ci.k != nil +} + +// Close the iterator. +func (ci *bptreeIterator) Close() error { + return ci.tx.Rollback() +} diff --git a/bptree_test.go b/bptree_test.go index 9b8142ac..3d17ac13 100644 --- a/bptree_test.go +++ b/bptree_test.go @@ -1,6 +1,7 @@ package lotusdb import ( + "bytes" "os" "path/filepath" "strconv" @@ -303,3 +304,121 @@ func testbptreeSync(t *testing.T, partitionNum int) { err = bt.Sync() assert.Nil(t, err) } + +func Test_bptreeIterator(t *testing.T) { + options := indexOptions{ + indexType: BTree, + dirPath: filepath.Join(os.TempDir(), "bptree-cursorIterator"+strconv.Itoa(1)), + partitionNum: 1, + keyHashFunction: xxhash.Sum64, + } + + err := os.MkdirAll(options.dirPath, os.ModePerm) + assert.Nil(t, err) + defer func() { + _ = os.RemoveAll(options.dirPath) + }() + bt, err := openBTreeIndex(options) + assert.Nil(t, err) + m := map[string]*wal.ChunkPosition{ + "key 0": {SegmentId: 0, BlockNumber: 0, ChunkOffset: 0, ChunkSize: 0}, + "key 1": {SegmentId: 1, BlockNumber: 1, ChunkOffset: 1, ChunkSize: 1}, + "key 2": {SegmentId: 2, BlockNumber: 2, ChunkOffset: 2, ChunkSize: 2}, + } + var keyPositions []*KeyPosition + keyPositions = append(keyPositions, &KeyPosition{ + key: []byte("key 0"), + partition: 0, + position: &wal.ChunkPosition{SegmentId: 0, BlockNumber: 0, ChunkOffset: 0, ChunkSize: 0}, + }, &KeyPosition{ + key: []byte("key 1"), + partition: 0, + position: &wal.ChunkPosition{SegmentId: 1, BlockNumber: 1, ChunkOffset: 1, ChunkSize: 1}, + }, &KeyPosition{ + key: []byte("key 2"), + partition: 0, + position: &wal.ChunkPosition{SegmentId: 2, BlockNumber: 2, ChunkOffset: 2, ChunkSize: 2}, + }, + ) + + err = bt.PutBatch(keyPositions) + assert.Nil(t, err) + + tree := bt.trees[0] + tx, err := tree.Begin(true) + assert.Nil(t, err) + iteratorOptions := IteratorOptions{ + Reverse: false, + } + + itr, err := NewBptreeIterator(tx, iteratorOptions) + assert.Nil(t, err) + var prev []byte + itr.Rewind() + for itr.Valid() { + currKey := itr.Key() + assert.True(t, prev == nil || bytes.Compare(prev, currKey) == -1) + assert.Equal(t, m[string(itr.Key())].Encode(), itr.Value()) + prev = currKey + itr.Next() + } + err = itr.Close() + assert.Nil(t, err) + + tx, err = tree.Begin(true) + assert.Nil(t, err) + iteratorOptions = IteratorOptions{ + Reverse: true, + } + prev = nil + + itr, err = NewBptreeIterator(tx, iteratorOptions) + assert.Nil(t, err) + itr.Rewind() + for itr.Valid() { + currKey := itr.Key() + assert.True(t, prev == nil || bytes.Compare(prev, currKey) == 1) + assert.Equal(t, m[string(itr.Key())].Encode(), itr.Value()) + prev = currKey + itr.Next() + } + itr.Seek([]byte("key 4")) + assert.Equal(t, []byte("key 2"), itr.Key()) + + itr.Seek([]byte("key 2")) + assert.Equal(t, []byte("key 2"), itr.Key()) + + itr.Seek([]byte("aye 2")) + assert.False(t, itr.Valid()) + err = itr.Close() + assert.Nil(t, err) + + tx, err = tree.Begin(true) + assert.Nil(t, err) + iteratorOptions = IteratorOptions{ + Reverse: false, + } + prev = nil + + itr, err = NewBptreeIterator(tx, iteratorOptions) + assert.Nil(t, err) + itr.Rewind() + for itr.Valid() { + currKey := itr.Key() + assert.True(t, prev == nil || bytes.Compare(prev, currKey) == -1) + assert.Equal(t, m[string(itr.Key())].Encode(), itr.Value()) + prev = currKey + itr.Next() + } + + itr.Seek([]byte("key 0")) + assert.Equal(t, []byte("key 0"), itr.Key()) + itr.Seek([]byte("key 4")) + assert.False(t, itr.Valid()) + + itr.Seek([]byte("aye 2")) + assert.Equal(t, []byte("key 0"), itr.Key()) + err = itr.Close() + assert.Nil(t, err) + +} diff --git a/db.go b/db.go index da2353c1..641b5acf 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,7 @@ package lotusdb import ( + "container/heap" "context" "fmt" "io" @@ -17,6 +18,7 @@ import ( "github.com/gofrs/flock" "github.com/rosedblabs/diskhash" "github.com/rosedblabs/wal" + "go.etcd.io/bbolt" "golang.org/x/sync/errgroup" ) @@ -571,3 +573,91 @@ func (db *DB) rewriteValidRecords(walFile *wal.WAL, validRecords []*ValueLogReco } return db.index.PutBatch(positions, matchKeys...) } + +func (db *DB) NewIterator(options IteratorOptions) (*MergeIterator, error) { + db.mu.Lock() + defer func() { + if r := recover(); r != nil { + db.mu.Unlock() + } + }() + itrs := make([]*SingleIter, 0, db.options.PartitionNum+len(db.immuMems)+1) + rank := 0 + txs := make([]*bbolt.Tx, db.options.PartitionNum) + index := db.index.(*BPTree) + for i := 0; i < db.options.PartitionNum; i++ { + tx, err := index.trees[i].Begin(false) + if err != nil { + return nil, err + } + txs[i] = tx + itr, err := NewBptreeIterator( + tx, + options, + ) + if err != nil { + return nil, err + } + itr.Rewind() + // is empty + if !itr.Valid() { + itr.Close() + continue + } + itrs = append(itrs, &SingleIter{ + iType: BptreeItr, + options: options, + rank: rank, + idx: rank, + iter: itr, + }) + + rank++ + } + + for i := 0; i < len(db.immuMems); i++ { + itr, err := NewMemtableIterator(options, db.immuMems[i]) + if err != nil { + return nil, err + } + itr.Rewind() + // is empty + if !itr.Valid() { + itr.Close() + continue + } + itrs = append(itrs, &SingleIter{ + iType: MemItr, + options: options, + rank: rank, + idx: rank, + iter: itr, + }) + rank++ + } + + itr, err := NewMemtableIterator(options, db.activeMem) + if err != nil { + return nil, err + } + itr.Rewind() + if itr.Valid() { + itrs = append(itrs, &SingleIter{ + iType: MemItr, + options: options, + rank: rank, + idx: rank, + iter: itr, + }) + } else { + itr.Close() + } + h := IterHeap(itrs) + heap.Init(&h) + + return &MergeIterator{ + h: h, + itrs: itrs, + db: db, + }, nil +} diff --git a/db_test.go b/db_test.go index 34a2daf6..e1be77d6 100644 --- a/db_test.go +++ b/db_test.go @@ -567,3 +567,188 @@ func TestDBMultiClients(t *testing.T) { wg.Wait() }) } + +func TestDBIterator(t *testing.T) { + options := DefaultOptions + path, err := os.MkdirTemp("", "db-test-iter") + assert.Nil(t, err) + options.DirPath = path + db, err := Open(options) + defer destroyDB(db) + assert.Nil(t, err) + db.immuMems = make([]*memtable, 3) + opts := memtableOptions{ + dirPath: path, + tableId: 0, + memSize: DefaultOptions.MemtableSize, + walBytesPerSync: DefaultOptions.BytesPerSync, + walSync: DefaultBatchOptions.Sync, + walBlockCache: DefaultOptions.BlockCache, + } + for i := 0; i < 3; i++ { + opts.tableId = uint32(i) + db.immuMems[i], err = openMemtable(opts) + assert.Nil(t, err) + } + logRecord_0 := []*LogRecord{ + // 0 + {[]byte("k3"), nil, LogRecordDeleted, 0}, + {[]byte("k1"), []byte("v1"), LogRecordNormal, 0}, + {[]byte("k1"), []byte("v1_1"), LogRecordNormal, 0}, + {[]byte("k2"), []byte("v1_1"), LogRecordNormal, 0}, + } + logRecord_1 := []*LogRecord{ + {[]byte("k1"), []byte("v2_1"), LogRecordNormal, 0}, + {[]byte("k2"), []byte("v2_1"), LogRecordNormal, 0}, + {[]byte("k2"), []byte("v2_2"), LogRecordNormal, 0}, + } + logRecord_2 := []*LogRecord{ + // 2 + {[]byte("k2"), nil, LogRecordDeleted, 0}, + } + logRecord_3 := []*LogRecord{ + {[]byte("k3"), []byte("v3_1"), LogRecordNormal, 0}, + } + + list2Map := func(in []*LogRecord) (out map[string]*LogRecord) { + out = make(map[string]*LogRecord) + for _, v := range in { + out[string(v.Key)] = v + } + return + } + db.immuMems[0].putBatch(list2Map(logRecord_0), 0, nil) + db.immuMems[1].putBatch(list2Map(logRecord_1), 1, nil) + db.immuMems[2].putBatch(list2Map(logRecord_2), 2, nil) + db.activeMem.putBatch(list2Map(logRecord_3), 3, nil) + expectedKey := [][]byte{ + []byte("k1"), + []byte("k3"), + } + expectedVal := [][]byte{ + []byte("v2_1"), + []byte("v3_1"), + } + iter, err := db.NewIterator(IteratorOptions{ + Reverse: false, + }) + assert.Nil(t, err) + var i int + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + + } else { + assert.Equal(t, expectedKey[2-i], iter.Key()) + assert.Equal(t, expectedVal[2-i], iter.Value()) + + } + i++ + iter.Next() + } + + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + + } else { + assert.Equal(t, expectedKey[2-i], iter.Key()) + assert.Equal(t, expectedVal[2-i], iter.Value()) + + } + i++ + iter.Next() + } + err = iter.Close() + assert.Nil(t, err) + + iter, err = db.NewIterator(IteratorOptions{ + Reverse: true, + }) + assert.Nil(t, err) + + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + + } else { + assert.Equal(t, expectedKey[1-i], iter.Key()) + assert.Equal(t, expectedVal[1-i], iter.Value()) + + } + i++ + iter.Next() + } + + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + + } else { + assert.Equal(t, expectedKey[1-i], iter.Key()) + assert.Equal(t, expectedVal[1-i], iter.Value()) + + } + i++ + iter.Next() + } + err = iter.Close() + assert.Nil(t, err) + + for j := 0; j < 3; j++ { + db.flushMemtable(db.immuMems[0]) + iter, err = db.NewIterator(IteratorOptions{ + Reverse: false, + }) + assert.Nil(t, err) + + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + } else { + assert.Equal(t, expectedKey[1-i], iter.Key()) + assert.Equal(t, expectedVal[1-i], iter.Value()) + } + iter.Next() + i++ + } + err = iter.Close() + assert.Nil(t, err) + + iter, err = db.NewIterator(IteratorOptions{ + Reverse: true, + }) + assert.Nil(t, err) + + iter.Rewind() + i = 0 + for iter.Valid() { + if !iter.itrs[0].options.Reverse { + assert.Equal(t, expectedKey[i], iter.Key()) + assert.Equal(t, expectedVal[i], iter.Value()) + } else { + assert.Equal(t, expectedKey[1-i], iter.Key()) + assert.Equal(t, expectedVal[1-i], iter.Value()) + } + iter.Next() + i++ + } + err = iter.Close() + assert.Nil(t, err) + } +} diff --git a/go.mod b/go.mod index d9b4d038..e54b116f 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/golang/glog v1.2.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/klauspost/compress v1.17.4 // indirect + github.com/kr/pretty v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 36bac5c7..a41dd7fa 100644 --- a/go.sum +++ b/go.sum @@ -30,7 +30,10 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/heap.go b/heap.go new file mode 100644 index 00000000..5c69e7f8 --- /dev/null +++ b/heap.go @@ -0,0 +1,75 @@ +package lotusdb + +import ( + "bytes" +) + +type iterType uint8 + +const ( + BptreeItr iterType = iota + MemItr +) + +// SingleIter element used to construct the heap,implementing the container.heap interface. +type SingleIter struct { + iType iterType + options IteratorOptions + rank int // A higher rank indicates newer data. + idx int // idx in heap + iter IteratorI +} + +type IterHeap []*SingleIter + +// Len is the number of elements in the collection. +func (ih IterHeap) Len() int { + return len(ih) +} + +// Less reports whether the element with index i +// must sort before the element with index j. +// +// If both Less(i, j) and Less(j, i) are false, +// then the elements at index i and j are considered equal. +// Sort may place equal elements in any order in the final result, +// while Stable preserves the original input order of equal elements. +// +// Less must describe a transitive ordering: +// - if both Less(i, j) and Less(j, k) are true, then Less(i, k) must be true as well. +// - if both Less(i, j) and Less(j, k) are false, then Less(i, k) must be false as well. +// +// Note that floating-point comparison (the < operator on float32 or float64 values) +// is not a transitive ordering when not-a-number (NaN) values are involved. +// See Float64Slice.Less for a correct implementation for floating-point values. +func (ih IterHeap) Less(i int, j int) bool { + ki, kj := ih[i].iter.Key(), ih[j].iter.Key() + if bytes.Equal(ki, kj) { + return ih[i].rank > ih[j].rank + } + if ih[i].options.Reverse { + return bytes.Compare(ki, kj) == 1 + } else { + return bytes.Compare(ki, kj) == -1 + } +} + +// Swap swaps the elements with indexes i and j. +func (ih IterHeap) Swap(i int, j int) { + ih[i], ih[j] = ih[j], ih[i] + ih[i].idx, ih[j].idx = i, j +} + +// Push add x as element Len(). +func (ih *IterHeap) Push(x any) { + *ih = append(*ih, x.(*SingleIter)) +} + +// Pop remove and return element Len() - 1. +func (ih *IterHeap) Pop() any { + old := *ih + n := len(old) + x := old[n-1] + *ih = old[0 : n-1] + return x +} diff --git a/iterator.go b/iterator.go new file mode 100644 index 00000000..3640c52d --- /dev/null +++ b/iterator.go @@ -0,0 +1,188 @@ +package lotusdb + +import ( + "bytes" + "container/heap" + + "github.com/dgraph-io/badger/v4/y" + "github.com/rosedblabs/wal" +) + +// IteratorI +type IteratorI interface { + // Rewind seek the first key in the iterator. + Rewind() + // Seek move the iterator to the key which is + // greater(less when reverse is true) than or equal to the specified key. + Seek(key []byte) + // Next moves the iterator to the next key. + Next() + // Key get the current key. + Key() []byte + // Value get the current value. + Value() any + // Valid returns whether the iterator is exhausted. + Valid() bool + // Close the iterator. + Close() error +} + +// MergeIterator holds a heap and a set of iterators that implement the IteratorI interface +type MergeIterator struct { + h IterHeap + itrs []*SingleIter // used for rebuilding heap + db *DB +} + +// Rewind seek the first key in the iterator. +func (mi *MergeIterator) Rewind() { + for _, v := range mi.itrs { + v.iter.Rewind() + } + h := IterHeap(mi.itrs) + heap.Init(&h) + mi.h = h +} + +// Seek move the iterator to the key which is +// greater(less when reverse is true) than or equal to the specified key. +func (mi *MergeIterator) Seek(key []byte) { + for i, v := range mi.h { + v.iter.Seek(key) + if !v.iter.Valid() { + heap.Remove(&mi.h, i) + } + } +} + +// cleanKey Remove all unused keys from all iterators. +// If the iterators become empty after clearing, remove them from the heap. +func (mi *MergeIterator) cleanKey(oldKey []byte, rank int) { + defer func() { + if r := recover(); r != nil { + mi.db.mu.Unlock() + } + }() + // delete all key == oldKey && rank < t.rank + copyedItrs := make([]*SingleIter, len(mi.itrs)) + // becouse heap.Remove heap.Fix may alter the order of elements in the slice. + copy(copyedItrs, mi.itrs) + for i := 0; i < len(copyedItrs); i++ { + singleIter := copyedItrs[i] + if singleIter.rank == rank || !singleIter.iter.Valid() { + continue + } + // 这里说明之前还是valid的 + for singleIter.iter.Valid() && + bytes.Equal(singleIter.iter.Key(), oldKey) { + if singleIter.rank > rank { + panic("rank error") + } + singleIter.iter.Next() + } + if !singleIter.iter.Valid() { + heap.Remove(&mi.h, singleIter.idx) + } else { + heap.Fix(&mi.h, singleIter.idx) + } + } +} + +// Next moves the iterator to the next key. +func (mi *MergeIterator) Next() { + // top item + singleIter := mi.h[0] + oldKey := singleIter.iter.Key() + mi.cleanKey(oldKey, singleIter.rank) + if !singleIter.iter.Valid() { + return + } + singleIter.iter.Next() + + if singleIter.iType == MemItr { + // check is deleteKey + for singleIter.iter.Valid() { + if valStruct, ok := singleIter.iter.Value().(y.ValueStruct); ok && valStruct.Meta == LogRecordDeleted { + mi.cleanKey(singleIter.iter.Key(), singleIter.rank) + if !singleIter.iter.Valid() { + return + } + singleIter.iter.Next() + } else { + break + } + } + } + if !singleIter.iter.Valid() { + heap.Remove(&mi.h, 0) + } else { + heap.Fix(&mi.h, 0) + } +} + +// Key get the current key. +func (mi *MergeIterator) Key() []byte { + return mi.h[0].iter.Key() +} + +// Value get the current value. +func (mi *MergeIterator) Value() []byte { + defer func() { + if r := recover(); r != nil { + mi.db.mu.Unlock() + } + }() + singleIter := mi.h[0] + if singleIter.iType == BptreeItr { + keyPos := new(KeyPosition) + keyPos.key = singleIter.iter.Key() + keyPos.partition = uint32(mi.db.vlog.getKeyPartition(singleIter.iter.Key())) + keyPos.position = wal.DecodeChunkPosition(singleIter.iter.Value().([]byte)) + record, err := mi.db.vlog.read(keyPos) + if err != nil { + panic(err) + } + return record.value + } else if singleIter.iType == MemItr { + return singleIter.iter.Value().(y.ValueStruct).Value + } else { + panic("iType not support") + } +} + +// Valid returns whether the iterator is exhausted. +func (mi *MergeIterator) Valid() bool { + if mi.h.Len() == 0 { + return false + } + singleIter := mi.h[0] + if singleIter.iType == MemItr && singleIter.iter.Value().(y.ValueStruct).Meta == LogRecordDeleted { + mi.cleanKey(singleIter.iter.Key(), singleIter.rank) + if !singleIter.iter.Valid() { + return false + } + singleIter.iter.Next() + if singleIter.iter.Valid() { + heap.Fix(&mi.h, 0) + return mi.Valid() + } else { + heap.Remove(&mi.h, 0) + } + } else if singleIter.iType == BptreeItr && !singleIter.iter.Valid() { + heap.Remove(&mi.h, 0) + } + return mi.h.Len() > 0 +} + +// Close the iterator. +func (mi *MergeIterator) Close() error { + for _, v := range mi.itrs { + err := v.iter.Close() + if err != nil { + mi.db.mu.Unlock() + return err + } + } + mi.db.mu.Unlock() + return nil +} diff --git a/memtable.go b/memtable.go index 8f51daf2..62701833 100644 --- a/memtable.go +++ b/memtable.go @@ -180,7 +180,7 @@ func (mt *memtable) putBatch(pendingWrites map[string]*LogRecord, return err } // flush wal if necessary - if options.Sync && !mt.options.walSync { + if options != nil && options.Sync && !mt.options.walSync { if err := mt.wal.Sync(); err != nil { return err } @@ -232,3 +232,53 @@ func (mt *memtable) sync() error { } return nil } + +// memtableIterator implement IteratorI +type memtableIterator struct { + options IteratorOptions + iter *arenaskl.UniIterator +} + +// NewMemtableIterator +func NewMemtableIterator(options IteratorOptions, memtable *memtable) (*memtableIterator, error) { + return &memtableIterator{ + options: options, + iter: memtable.skl.NewUniIterator(options.Reverse), + }, nil +} + +// Rewind seek the first key in the iterator. +func (mi *memtableIterator) Rewind() { + mi.iter.Rewind() +} + +// Seek move the iterator to the key which is +// greater(less when reverse is true) than or equal to the specified key. +func (mi *memtableIterator) Seek(key []byte) { + mi.iter.Seek(y.KeyWithTs(key, 0)) +} + +// Next moves the iterator to the next key. +func (mi *memtableIterator) Next() { + mi.iter.Next() +} + +// Key get the current key. +func (mi *memtableIterator) Key() []byte { + return y.ParseKey(mi.iter.Key()) +} + +// Value get the current value. +func (mi *memtableIterator) Value() any { + return mi.iter.Value() +} + +// Valid returns whether the iterator is exhausted. +func (mi *memtableIterator) Valid() bool { + return mi.iter.Valid() +} + +// Close the iterator. +func (mi *memtableIterator) Close() error { + return mi.iter.Close() +} diff --git a/memtable_test.go b/memtable_test.go index 394a82f9..50dc04c5 100644 --- a/memtable_test.go +++ b/memtable_test.go @@ -1,10 +1,12 @@ package lotusdb import ( + "bytes" "os" "testing" "github.com/bwmarrin/snowflake" + "github.com/dgraph-io/badger/v4/y" "github.com/lotusdblabs/lotusdb/v2/util" "github.com/stretchr/testify/assert" ) @@ -465,3 +467,131 @@ func TestMemtableClose(t *testing.T) { assert.Nil(t, err) }) } + +func TestNewMemtableIterator(t *testing.T) { + path, err := os.MkdirTemp("", "memtable-test-iterator-new") + assert.Nil(t, err) + + defer func() { + _ = os.RemoveAll(path) + }() + + opts := memtableOptions{ + dirPath: path, + tableId: 0, + memSize: DefaultOptions.MemtableSize, + walBytesPerSync: DefaultOptions.BytesPerSync, + walSync: DefaultBatchOptions.Sync, + walBlockCache: DefaultOptions.BlockCache, + } + + table, err := openMemtable(opts) + defer table.close() + assert.Nil(t, err) + + options := IteratorOptions{ + Reverse: false, + } + iter, err := NewMemtableIterator(options, table) + assert.Nil(t, err) + + err = iter.Close() + assert.Nil(t, err) +} + +func Test_memtableIterator(t *testing.T) { + path, err := os.MkdirTemp("", "memtable-test-iterator-rewind") + assert.Nil(t, err) + + defer func() { + _ = os.RemoveAll(path) + }() + + opts := memtableOptions{ + dirPath: path, + tableId: 0, + memSize: DefaultOptions.MemtableSize, + walBytesPerSync: DefaultOptions.BytesPerSync, + walSync: DefaultBatchOptions.Sync, + walBlockCache: DefaultOptions.BlockCache, + } + table, err := openMemtable(opts) + assert.Nil(t, err) + + writeOpts := &WriteOptions{ + Sync: false, + DisableWal: false, + } + node, err := snowflake.NewNode(1) + assert.Nil(t, err) + writeLogs := map[string]*LogRecord{ + "key 0": {Key: []byte("key 0"), Value: []byte("value 0"), Type: LogRecordNormal}, + "key 1": {Key: nil, Value: []byte("value 1"), Type: LogRecordNormal}, + "key 2": {Key: []byte("key 2"), Value: []byte(""), Type: LogRecordNormal}, + } + + err = table.putBatch(writeLogs, node.Generate(), writeOpts) + assert.Nil(t, err) + + iteratorOptions := IteratorOptions{ + Reverse: false, + } + itr, err := NewMemtableIterator(iteratorOptions, table) + assert.Nil(t, err) + var prev []byte + itr.Rewind() + for itr.Valid() { + currKey := itr.Key() + assert.True(t, prev == nil || bytes.Compare(prev, currKey) == -1) + assert.Equal(t, writeLogs[string(currKey)].Value, itr.Value().(y.ValueStruct).Value) + assert.Equal(t, writeLogs[string(currKey)].Type, itr.Value().(y.ValueStruct).Meta) + prev = currKey + itr.Next() + } + err = itr.Close() + assert.Nil(t, err) + + iteratorOptions.Reverse = true + prev = nil + itr, err = NewMemtableIterator(iteratorOptions, table) + assert.Nil(t, err) + itr.Rewind() + for itr.Valid() { + currKey := itr.Key() + assert.True(t, prev == nil || bytes.Compare(prev, currKey) == 1) + prev = currKey + assert.Equal(t, writeLogs[string(currKey)].Value, itr.Value().(y.ValueStruct).Value) + assert.Equal(t, writeLogs[string(currKey)].Type, itr.Value().(y.ValueStruct).Meta) + itr.Next() + } + err = itr.Close() + assert.Nil(t, err) + + iteratorOptions.Reverse = false + itr, err = NewMemtableIterator(iteratorOptions, table) + assert.Nil(t, err) + itr.Seek([]byte("key 0")) + assert.Equal(t, []byte("key 0"), itr.Key()) + itr.Seek([]byte("key 4")) + assert.False(t, itr.Valid()) + + itr.Seek([]byte("aye 2")) + assert.Equal(t, []byte("key 0"), itr.Key()) + err = itr.Close() + assert.Nil(t, err) + + iteratorOptions.Reverse = true + itr, err = NewMemtableIterator(iteratorOptions, table) + assert.Nil(t, err) + itr.Seek([]byte("key 4")) + assert.Equal(t, []byte("key 2"), itr.Key()) + + itr.Seek([]byte("key 2")) + assert.Equal(t, []byte("key 2"), itr.Key()) + + itr.Seek([]byte("aye 2")) + assert.False(t, itr.Valid()) + + err = itr.Close() + assert.Nil(t, err) +}