Skip to content

Commit

Permalink
fix: Fixes memiterator (#21775)
Browse files Browse the repository at this point in the history
Co-authored-by: marbar3778 <[email protected]>
  • Loading branch information
alpe and tac0turtle authored Sep 20, 2024
1 parent 8d248f0 commit aa90bb4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
7 changes: 5 additions & 2 deletions server/v2/stf/branch/changeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type memIterator struct {
}

// newMemIterator creates a new memory iterator for a given range of keys in a B-tree.
// The iterator starts at the specified start key and ends at the specified end key.
// The iterator creates a copy then starts at the specified start key and ends at the specified end key.
// The `tree` parameter is the B-tree to iterate over.
// The `ascending` parameter determines the direction of iteration.
// If `ascending` is true, the iterator will iterate in ascending order.
Expand All @@ -111,7 +111,7 @@ type memIterator struct {
// The `valid` field of the iterator indicates whether the iterator is positioned at a valid key.
// The `start` and `end` fields of the iterator store the start and end keys respectively.
func newMemIterator(start, end []byte, tree *btree.BTreeG[item], ascending bool) *memIterator {
iter := tree.Iter()
iter := tree.Copy().Iter()
var valid bool
if ascending {
if start != nil {
Expand Down Expand Up @@ -207,6 +207,9 @@ func (mi *memIterator) keyInRange(key []byte) bool {
if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 {
return false
}
if !mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 {
return false
}
return true
}

Expand Down
61 changes: 60 additions & 1 deletion server/v2/stf/branch/changeset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"testing"
)

func Test_memIterator(t *testing.T) {
func TestMemIteratorWithWriteToRebalance(t *testing.T) {
t.Run("iter is invalid after close", func(t *testing.T) {
cs := newChangeSet()
for i := byte(0); i < 32; i++ {
Expand All @@ -26,3 +26,62 @@ func Test_memIterator(t *testing.T) {
}
})
}

func TestKeyInRange(t *testing.T) {
specs := map[string]struct {
mi *memIterator
src []byte
exp bool
}{
"equal start": {
mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}},
src: []byte{0},
exp: true,
},
"equal end": {
mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}},
src: []byte{2},
exp: false,
},
"between": {
mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}},
src: []byte{1},
exp: true,
},
"equal start - open end": {
mi: &memIterator{ascending: true, start: []byte{0}},
src: []byte{0},
exp: true,
},
"greater start - open end": {
mi: &memIterator{ascending: true, start: []byte{0}},
src: []byte{2},
exp: true,
},
"equal end - open start": {
mi: &memIterator{ascending: true, end: []byte{2}},
src: []byte{2},
exp: false,
},
"smaller end - open start": {
mi: &memIterator{ascending: true, end: []byte{2}},
src: []byte{1},
exp: true,
},
}
for name, spec := range specs {
for _, asc := range []bool{true, false} {
order := "asc_"
if !asc {
order = "desc_"
}
t.Run(order+name, func(t *testing.T) {
spec.mi.ascending = asc
got := spec.mi.keyInRange(spec.src)
if spec.exp != got {
t.Errorf("expected %v, got %v", spec.exp, got)
}
})
}
}
}

0 comments on commit aa90bb4

Please sign in to comment.