diff --git a/containers/containers.go b/containers/containers.go index a512a3cb..1c016aa6 100644 --- a/containers/containers.go +++ b/containers/containers.go @@ -13,24 +13,39 @@ // Serialization provides serializers (marshalers) and deserializers (unmarshalers). package containers -import "github.com/emirpasic/gods/utils" +import ( + "cmp" + "slices" + + "github.com/emirpasic/gods/v2/utils" +) // Container is base interface that all data structures implement. -type Container interface { +type Container[T any] interface { Empty() bool Size() int Clear() - Values() []interface{} + Values() []T String() string } // GetSortedValues returns sorted container's elements with respect to the passed comparator. // Does not affect the ordering of elements within the container. -func GetSortedValues(container Container, comparator utils.Comparator) []interface{} { +func GetSortedValues[T cmp.Ordered](container Container[T]) []T { + values := container.Values() + if len(values) < 2 { + return values + } + slices.Sort(values) + return values +} + +// GetSortedValuesFunc is the equivalent of GetSortedValues for containers of values that are not ordered. +func GetSortedValuesFunc[T any](container Container[T], comparator utils.Comparator[T]) []T { values := container.Values() if len(values) < 2 { return values } - utils.Sort(values, comparator) + slices.SortFunc(values, comparator) return values } diff --git a/containers/containers_test.go b/containers/containers_test.go index e92d123d..06763b43 100644 --- a/containers/containers_test.go +++ b/containers/containers_test.go @@ -7,34 +7,34 @@ package containers import ( + "cmp" "fmt" - "github.com/emirpasic/gods/utils" "strings" "testing" ) // For testing purposes -type ContainerTest struct { - values []interface{} +type ContainerTest[T any] struct { + values []T } -func (container ContainerTest) Empty() bool { +func (container ContainerTest[T]) Empty() bool { return len(container.values) == 0 } -func (container ContainerTest) Size() int { +func (container ContainerTest[T]) Size() int { return len(container.values) } -func (container ContainerTest) Clear() { - container.values = []interface{}{} +func (container ContainerTest[T]) Clear() { + container.values = []T{} } -func (container ContainerTest) Values() []interface{} { +func (container ContainerTest[T]) Values() []T { return container.values } -func (container ContainerTest) String() string { +func (container ContainerTest[T]) String() string { str := "ContainerTest\n" var values []string for _, value := range container.values { @@ -45,24 +45,32 @@ func (container ContainerTest) String() string { } func TestGetSortedValuesInts(t *testing.T) { - container := ContainerTest{} - GetSortedValues(container, utils.IntComparator) - container.values = []interface{}{5, 1, 3, 2, 4} - values := GetSortedValues(container, utils.IntComparator) + container := ContainerTest[int]{} + GetSortedValues(container) + container.values = []int{5, 1, 3, 2, 4} + values := GetSortedValues(container) for i := 1; i < container.Size(); i++ { - if values[i-1].(int) > values[i].(int) { + if values[i-1] > values[i] { t.Errorf("Not sorted!") } } } -func TestGetSortedValuesStrings(t *testing.T) { - container := ContainerTest{} - GetSortedValues(container, utils.StringComparator) - container.values = []interface{}{"g", "a", "d", "e", "f", "c", "b"} - values := GetSortedValues(container, utils.StringComparator) +type NotInt struct { + i int +} + +func TestGetSortedValuesNotInts(t *testing.T) { + container := ContainerTest[NotInt]{} + GetSortedValuesFunc(container, func(x, y NotInt) int { + return cmp.Compare(x.i, y.i) + }) + container.values = []NotInt{{5}, {1}, {3}, {2}, {4}} + values := GetSortedValuesFunc(container, func(x, y NotInt) int { + return cmp.Compare(x.i, y.i) + }) for i := 1; i < container.Size(); i++ { - if values[i-1].(string) > values[i].(string) { + if values[i-1].i > values[i].i { t.Errorf("Not sorted!") } } diff --git a/containers/enumerable.go b/containers/enumerable.go index 70660054..1121388c 100644 --- a/containers/enumerable.go +++ b/containers/enumerable.go @@ -5,9 +5,9 @@ package containers // EnumerableWithIndex provides functions for ordered containers whose values can be fetched by an index. -type EnumerableWithIndex interface { +type EnumerableWithIndex[T any] interface { // Each calls the given function once for each element, passing that element's index and value. - Each(func(index int, value interface{})) + Each(func(index int, value T)) // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. @@ -18,22 +18,22 @@ type EnumerableWithIndex interface { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. - Any(func(index int, value interface{}) bool) bool + Any(func(index int, value T) bool) bool // All passes each element of the container to the given function and // returns true if the function returns true for all elements. - All(func(index int, value interface{}) bool) bool + All(func(index int, value T) bool) bool // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. - Find(func(index int, value interface{}) bool) (int, interface{}) + Find(func(index int, value T) bool) (int, T) } // EnumerableWithKey provides functions for ordered containers whose values whose elements are key/value pairs. -type EnumerableWithKey interface { +type EnumerableWithKey[K, V any] interface { // Each calls the given function once for each element, passing that element's key and value. - Each(func(key interface{}, value interface{})) + Each(func(key K, value V)) // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. @@ -44,14 +44,14 @@ type EnumerableWithKey interface { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. - Any(func(key interface{}, value interface{}) bool) bool + Any(func(key K, value V) bool) bool // All passes each element of the container to the given function and // returns true if the function returns true for all elements. - All(func(key interface{}, value interface{}) bool) bool + All(func(key K, value V) bool) bool // Find passes each element of the container to the given function and returns // the first (key,value) for which the function is true or nil,nil otherwise if no element // matches the criteria. - Find(func(key interface{}, value interface{}) bool) (interface{}, interface{}) + Find(func(key K, value V) bool) (K, V) } diff --git a/containers/iterator.go b/containers/iterator.go index 73994ec8..68f4b5d5 100644 --- a/containers/iterator.go +++ b/containers/iterator.go @@ -5,7 +5,7 @@ package containers // IteratorWithIndex is stateful iterator for ordered containers whose values can be fetched by an index. -type IteratorWithIndex interface { +type IteratorWithIndex[T any] interface { // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. @@ -14,7 +14,7 @@ type IteratorWithIndex interface { // Value returns the current element's value. // Does not modify the state of the iterator. - Value() interface{} + Value() T // Index returns the current element's index. // Does not modify the state of the iterator. @@ -33,11 +33,11 @@ type IteratorWithIndex interface { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. - NextTo(func(index int, value interface{}) bool) bool + NextTo(func(index int, value T) bool) bool } // IteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs. -type IteratorWithKey interface { +type IteratorWithKey[K, V any] interface { // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. @@ -46,11 +46,11 @@ type IteratorWithKey interface { // Value returns the current element's value. // Does not modify the state of the iterator. - Value() interface{} + Value() V // Key returns the current element's key. // Does not modify the state of the iterator. - Key() interface{} + Key() K // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. @@ -65,19 +65,19 @@ type IteratorWithKey interface { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. - NextTo(func(key interface{}, value interface{}) bool) bool + NextTo(func(key K, value V) bool) bool } // ReverseIteratorWithIndex is stateful iterator for ordered containers whose values can be fetched by an index. // // Essentially it is the same as IteratorWithIndex, but provides additional: // -// Prev() function to enable traversal in reverse +// # Prev() function to enable traversal in reverse // // Last() function to move the iterator to the last element. // // End() function to move the iterator past the last element (one-past-the-end). -type ReverseIteratorWithIndex interface { +type ReverseIteratorWithIndex[T any] interface { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. @@ -96,19 +96,19 @@ type ReverseIteratorWithIndex interface { // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. - PrevTo(func(index int, value interface{}) bool) bool + PrevTo(func(index int, value T) bool) bool - IteratorWithIndex + IteratorWithIndex[T] } // ReverseIteratorWithKey is a stateful iterator for ordered containers whose elements are key value pairs. // // Essentially it is the same as IteratorWithKey, but provides additional: // -// Prev() function to enable traversal in reverse +// # Prev() function to enable traversal in reverse // // Last() function to move the iterator to the last element. -type ReverseIteratorWithKey interface { +type ReverseIteratorWithKey[K, V any] interface { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. @@ -127,7 +127,7 @@ type ReverseIteratorWithKey interface { // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. - PrevTo(func(key interface{}, value interface{}) bool) bool + PrevTo(func(key K, value V) bool) bool - IteratorWithKey + IteratorWithKey[K, V] } diff --git a/examples/arraylist/arraylist.go b/examples/arraylist/arraylist.go index 4d4fbd90..fb98d64c 100644 --- a/examples/arraylist/arraylist.go +++ b/examples/arraylist/arraylist.go @@ -5,16 +5,17 @@ package main import ( - "github.com/emirpasic/gods/lists/arraylist" - "github.com/emirpasic/gods/utils" + "cmp" + + "github.com/emirpasic/gods/v2/lists/arraylist" ) // ArrayListExample to demonstrate basic usage of ArrayList func main() { - list := arraylist.New() + list := arraylist.New[string]() list.Add("a") // ["a"] list.Add("c", "b") // ["a","c","b"] - list.Sort(utils.StringComparator) // ["a","b","c"] + list.Sort(cmp.Compare[string]) // ["a","b","c"] _, _ = list.Get(0) // "a",true _, _ = list.Get(100) // nil,false _ = list.Contains("a", "b", "c") // true diff --git a/examples/arrayqueue/arrayqqueue.go b/examples/arrayqueue/arrayqqueue.go index 13b88187..bc4c4584 100644 --- a/examples/arrayqueue/arrayqqueue.go +++ b/examples/arrayqueue/arrayqqueue.go @@ -4,11 +4,11 @@ package main -import aq "github.com/emirpasic/gods/queues/arrayqueue" +import aq "github.com/emirpasic/gods/v2/queues/arrayqueue" // ArrayQueueExample to demonstrate basic usage of ArrayQueue func main() { - queue := aq.New() // empty + queue := aq.New[int]() // empty queue.Enqueue(1) // 1 queue.Enqueue(2) // 1, 2 _ = queue.Values() // 1, 2 (FIFO order) diff --git a/examples/arraystack/arraystack.go b/examples/arraystack/arraystack.go index aa06eaf7..1a54f370 100644 --- a/examples/arraystack/arraystack.go +++ b/examples/arraystack/arraystack.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/stacks/arraystack" +import "github.com/emirpasic/gods/v2/stacks/arraystack" // ArrayStackExample to demonstrate basic usage of ArrayStack func main() { - stack := arraystack.New() // empty - stack.Push(1) // 1 - stack.Push(2) // 1, 2 - stack.Values() // 2, 1 (LIFO order) - _, _ = stack.Peek() // 2,true - _, _ = stack.Pop() // 2, true - _, _ = stack.Pop() // 1, true - _, _ = stack.Pop() // nil, false (nothing to pop) - stack.Push(1) // 1 - stack.Clear() // empty - stack.Empty() // true - stack.Size() // 0 + stack := arraystack.New[int]() // empty + stack.Push(1) // 1 + stack.Push(2) // 1, 2 + stack.Values() // 2, 1 (LIFO order) + _, _ = stack.Peek() // 2,true + _, _ = stack.Pop() // 2, true + _, _ = stack.Pop() // 1, true + _, _ = stack.Pop() // nil, false (nothing to pop) + stack.Push(1) // 1 + stack.Clear() // empty + stack.Empty() // true + stack.Size() // 0 } diff --git a/examples/avltree/avltree.go b/examples/avltree/avltree.go index b6d1aabc..5499ae0d 100644 --- a/examples/avltree/avltree.go +++ b/examples/avltree/avltree.go @@ -6,12 +6,13 @@ package main import ( "fmt" - avl "github.com/emirpasic/gods/trees/avltree" + + avl "github.com/emirpasic/gods/v2/trees/avltree" ) // AVLTreeExample to demonstrate basic usage of AVLTree func main() { - tree := avl.NewWithIntComparator() // empty(keys are of type int) + tree := avl.New[int, string]() // empty(keys are of type int) tree.Put(1, "x") // 1->x tree.Put(2, "b") // 1->x, 2->b (in order) diff --git a/examples/binaryheap/binaryheap.go b/examples/binaryheap/binaryheap.go index 4bc93813..ad680389 100644 --- a/examples/binaryheap/binaryheap.go +++ b/examples/binaryheap/binaryheap.go @@ -5,32 +5,33 @@ package main import ( - "github.com/emirpasic/gods/trees/binaryheap" - "github.com/emirpasic/gods/utils" + "cmp" + + "github.com/emirpasic/gods/v2/trees/binaryheap" ) // BinaryHeapExample to demonstrate basic usage of BinaryHeap func main() { // Min-heap - heap := binaryheap.NewWithIntComparator() // empty (min-heap) - heap.Push(2) // 2 - heap.Push(3) // 2, 3 - heap.Push(1) // 1, 3, 2 - heap.Values() // 1, 3, 2 - _, _ = heap.Peek() // 1,true - _, _ = heap.Pop() // 1, true - _, _ = heap.Pop() // 2, true - _, _ = heap.Pop() // 3, true - _, _ = heap.Pop() // nil, false (nothing to pop) - heap.Push(1) // 1 - heap.Clear() // empty - heap.Empty() // true - heap.Size() // 0 + heap := binaryheap.New[int]() // empty (min-heap) + heap.Push(2) // 2 + heap.Push(3) // 2, 3 + heap.Push(1) // 1, 3, 2 + heap.Values() // 1, 3, 2 + _, _ = heap.Peek() // 1,true + _, _ = heap.Pop() // 1, true + _, _ = heap.Pop() // 2, true + _, _ = heap.Pop() // 3, true + _, _ = heap.Pop() // nil, false (nothing to pop) + heap.Push(1) // 1 + heap.Clear() // empty + heap.Empty() // true + heap.Size() // 0 // Max-heap - inverseIntComparator := func(a, b interface{}) int { - return -utils.IntComparator(a, b) + inverseIntComparator := func(a, b int) int { + return -cmp.Compare(a, b) } heap = binaryheap.NewWith(inverseIntComparator) // empty (min-heap) heap.Push(2) // 2 diff --git a/examples/btree/btree.go b/examples/btree/btree.go index ea61b03f..fdf87d11 100644 --- a/examples/btree/btree.go +++ b/examples/btree/btree.go @@ -6,12 +6,13 @@ package main import ( "fmt" - "github.com/emirpasic/gods/trees/btree" + + "github.com/emirpasic/gods/v2/trees/btree" ) // BTreeExample to demonstrate basic usage of BTree func main() { - tree := btree.NewWithIntComparator(3) // empty (keys are of type int) + tree := btree.New[int, string](3) // empty (keys are of type int) tree.Put(1, "x") // 1->x tree.Put(2, "b") // 1->x, 2->b (in order) diff --git a/examples/circularbuffer/circularbuffer.go b/examples/circularbuffer/circularbuffer.go index 3bd5f2ae..1d230033 100644 --- a/examples/circularbuffer/circularbuffer.go +++ b/examples/circularbuffer/circularbuffer.go @@ -4,23 +4,23 @@ package main -import cb "github.com/emirpasic/gods/queues/circularbuffer" +import cb "github.com/emirpasic/gods/v2/queues/circularbuffer" // CircularBufferExample to demonstrate basic usage of CircularBuffer func main() { - queue := cb.New(3) // empty (max size is 3) - queue.Enqueue(1) // 1 - queue.Enqueue(2) // 1, 2 - queue.Enqueue(3) // 1, 2, 3 - _ = queue.Values() // 1, 2, 3 - queue.Enqueue(3) // 4, 2, 3 - _, _ = queue.Peek() // 4,true - _, _ = queue.Dequeue() // 4, true - _, _ = queue.Dequeue() // 2, true - _, _ = queue.Dequeue() // 3, true - _, _ = queue.Dequeue() // nil, false (nothing to deque) - queue.Enqueue(1) // 1 - queue.Clear() // empty - queue.Empty() // true - _ = queue.Size() // 0 + queue := cb.New[int](3) // empty (max size is 3) + queue.Enqueue(1) // 1 + queue.Enqueue(2) // 1, 2 + queue.Enqueue(3) // 1, 2, 3 + _ = queue.Values() // 1, 2, 3 + queue.Enqueue(3) // 4, 2, 3 + _, _ = queue.Peek() // 4,true + _, _ = queue.Dequeue() // 4, true + _, _ = queue.Dequeue() // 2, true + _, _ = queue.Dequeue() // 3, true + _, _ = queue.Dequeue() // nil, false (nothing to deque) + queue.Enqueue(1) // 1 + queue.Clear() // empty + queue.Empty() // true + _ = queue.Size() // 0 } diff --git a/examples/customcomparator/customcomparator.go b/examples/customcomparator/customcomparator.go index b61d9698..0a5d9572 100644 --- a/examples/customcomparator/customcomparator.go +++ b/examples/customcomparator/customcomparator.go @@ -6,7 +6,8 @@ package main import ( "fmt" - "github.com/emirpasic/gods/sets/treeset" + + "github.com/emirpasic/gods/v2/sets/treeset" ) // User model (id and name) @@ -16,16 +17,12 @@ type User struct { } // Comparator function (sort by IDs) -func byID(a, b interface{}) int { - - // Type assertion, program will panic if this is not respected - c1 := a.(User) - c2 := b.(User) +func byID(a, b User) int { switch { - case c1.id > c2.id: + case a.id > b.id: return 1 - case c1.id < c2.id: + case a.id < b.id: return -1 default: return 0 diff --git a/examples/doublylinkedlist/doublylinkedlist.go b/examples/doublylinkedlist/doublylinkedlist.go index 99ec995c..53b4a897 100644 --- a/examples/doublylinkedlist/doublylinkedlist.go +++ b/examples/doublylinkedlist/doublylinkedlist.go @@ -5,17 +5,18 @@ package main import ( - dll "github.com/emirpasic/gods/lists/doublylinkedlist" - "github.com/emirpasic/gods/utils" + "cmp" + + dll "github.com/emirpasic/gods/v2/lists/doublylinkedlist" ) // DoublyLinkedListExample to demonstrate basic usage of DoublyLinkedList func main() { - list := dll.New() + list := dll.New[string]() list.Add("a") // ["a"] list.Append("b") // ["a","b"] (same as Add()) list.Prepend("c") // ["c","a","b"] - list.Sort(utils.StringComparator) // ["a","b","c"] + list.Sort(cmp.Compare[string]) // ["a","b","c"] _, _ = list.Get(0) // "a",true _, _ = list.Get(100) // nil,false _ = list.Contains("a", "b", "c") // true diff --git a/examples/enumerablewithindex/enumerablewithindex.go b/examples/enumerablewithindex/enumerablewithindex.go index 95459117..f3dd7419 100644 --- a/examples/enumerablewithindex/enumerablewithindex.go +++ b/examples/enumerablewithindex/enumerablewithindex.go @@ -6,12 +6,13 @@ package main import ( "fmt" - "github.com/emirpasic/gods/sets/treeset" + + "github.com/emirpasic/gods/v2/sets/treeset" ) -func printSet(txt string, set *treeset.Set) { +func printSet(txt string, set *treeset.Set[int]) { fmt.Print(txt, "[ ") - set.Each(func(index int, value interface{}) { + set.Each(func(index int, value int) { fmt.Print(value, " ") }) fmt.Println("]") @@ -19,41 +20,41 @@ func printSet(txt string, set *treeset.Set) { // EnumerableWithIndexExample to demonstrate basic usage of EnumerableWithIndex func main() { - set := treeset.NewWithIntComparator() + set := treeset.New[int]() set.Add(2, 3, 4, 2, 5, 6, 7, 8) printSet("Initial", set) // [ 2 3 4 5 6 7 8 ] - even := set.Select(func(index int, value interface{}) bool { - return value.(int)%2 == 0 + even := set.Select(func(index int, value int) bool { + return value%2 == 0 }) printSet("Even numbers", even) // [ 2 4 6 8 ] - foundIndex, foundValue := set.Find(func(index int, value interface{}) bool { - return value.(int)%2 == 0 && value.(int)%3 == 0 + foundIndex, foundValue := set.Find(func(index int, value int) bool { + return value%2 == 0 && value%3 == 0 }) if foundIndex != -1 { fmt.Println("Number divisible by 2 and 3 found is", foundValue, "at index", foundIndex) // value: 6, index: 4 } - square := set.Map(func(index int, value interface{}) interface{} { - return value.(int) * value.(int) + square := set.Map(func(index int, value int) int { + return value * value }) printSet("Numbers squared", square) // [ 4 9 16 25 36 49 64 ] - bigger := set.Any(func(index int, value interface{}) bool { - return value.(int) > 5 + bigger := set.Any(func(index int, value int) bool { + return value > 5 }) fmt.Println("Set contains a number bigger than 5 is ", bigger) // true - positive := set.All(func(index int, value interface{}) bool { - return value.(int) > 0 + positive := set.All(func(index int, value int) bool { + return value > 0 }) fmt.Println("All numbers are positive is", positive) // true - evenNumbersSquared := set.Select(func(index int, value interface{}) bool { - return value.(int)%2 == 0 - }).Map(func(index int, value interface{}) interface{} { - return value.(int) * value.(int) + evenNumbersSquared := set.Select(func(index int, value int) bool { + return value%2 == 0 + }).Map(func(index int, value int) int { + return value * value }) printSet("Chaining", evenNumbersSquared) // [ 4 16 36 64 ] } diff --git a/examples/enumerablewithkey/enumerablewithkey.go b/examples/enumerablewithkey/enumerablewithkey.go index 7f05040d..0bfa4c08 100644 --- a/examples/enumerablewithkey/enumerablewithkey.go +++ b/examples/enumerablewithkey/enumerablewithkey.go @@ -6,12 +6,13 @@ package main import ( "fmt" - "github.com/emirpasic/gods/maps/treemap" + + "github.com/emirpasic/gods/v2/maps/treemap" ) -func printMap(txt string, m *treemap.Map) { +func printMap(txt string, m *treemap.Map[string, int]) { fmt.Print(txt, " { ") - m.Each(func(key interface{}, value interface{}) { + m.Each(func(key string, value int) { fmt.Print(key, ":", value, " ") }) fmt.Println("}") @@ -19,7 +20,7 @@ func printMap(txt string, m *treemap.Map) { // EunumerableWithKeyExample to demonstrate basic usage of EunumerableWithKey func main() { - m := treemap.NewWithStringComparator() + m := treemap.New[string, int]() m.Put("g", 7) m.Put("f", 6) m.Put("e", 5) @@ -29,37 +30,37 @@ func main() { m.Put("a", 1) printMap("Initial", m) // { a:1 b:2 c:3 d:4 e:5 f:6 g:7 } - even := m.Select(func(key interface{}, value interface{}) bool { - return value.(int)%2 == 0 + even := m.Select(func(key string, value int) bool { + return value%2 == 0 }) printMap("Elements with even values", even) // { b:2 d:4 f:6 } - foundKey, foundValue := m.Find(func(key interface{}, value interface{}) bool { - return value.(int)%2 == 0 && value.(int)%3 == 0 + foundKey, foundValue := m.Find(func(key string, value int) bool { + return value%2 == 0 && value%3 == 0 }) - if foundKey != nil { + if foundKey != "" { fmt.Println("Element with value divisible by 2 and 3 found is", foundValue, "with key", foundKey) // value: 6, index: 4 } - square := m.Map(func(key interface{}, value interface{}) (interface{}, interface{}) { - return key.(string) + key.(string), value.(int) * value.(int) + square := m.Map(func(key string, value int) (string, int) { + return key + key, value * value }) printMap("Elements' values squared and letters duplicated", square) // { aa:1 bb:4 cc:9 dd:16 ee:25 ff:36 gg:49 } - bigger := m.Any(func(key interface{}, value interface{}) bool { - return value.(int) > 5 + bigger := m.Any(func(key string, value int) bool { + return value > 5 }) fmt.Println("Map contains element whose value is bigger than 5 is", bigger) // true - positive := m.All(func(key interface{}, value interface{}) bool { - return value.(int) > 0 + positive := m.All(func(key string, value int) bool { + return value > 0 }) fmt.Println("All map's elements have positive values is", positive) // true - evenNumbersSquared := m.Select(func(key interface{}, value interface{}) bool { - return value.(int)%2 == 0 - }).Map(func(key interface{}, value interface{}) (interface{}, interface{}) { - return key, value.(int) * value.(int) + evenNumbersSquared := m.Select(func(key string, value int) bool { + return value%2 == 0 + }).Map(func(key string, value int) (string, int) { + return key, value * value }) printMap("Chaining", evenNumbersSquared) // { b:4 d:16 f:36 } } diff --git a/examples/godsort/godsort.go b/examples/godsort/godsort.go deleted file mode 100644 index 9f0bd2ac..00000000 --- a/examples/godsort/godsort.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2015, Emir Pasic. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import "github.com/emirpasic/gods/utils" - -// SortExample to demonstrate basic usage of basic sort -func main() { - strings := []interface{}{} // [] - strings = append(strings, "d") // ["d"] - strings = append(strings, "a") // ["d","a"] - strings = append(strings, "b") // ["d","a",b" - strings = append(strings, "c") // ["d","a",b","c"] - utils.Sort(strings, utils.StringComparator) // ["a","b","c","d"] -} diff --git a/examples/hashbidimap/hashbidimap.go b/examples/hashbidimap/hashbidimap.go index 26350b8c..cc297668 100644 --- a/examples/hashbidimap/hashbidimap.go +++ b/examples/hashbidimap/hashbidimap.go @@ -4,22 +4,22 @@ package main -import "github.com/emirpasic/gods/maps/hashbidimap" +import "github.com/emirpasic/gods/v2/maps/hashbidimap" // HashBidiMapExample to demonstrate basic usage of HashMap func main() { - m := hashbidimap.New() // empty - m.Put(1, "x") // 1->x - m.Put(3, "b") // 1->x, 3->b (random order) - m.Put(1, "a") // 1->a, 3->b (random order) - m.Put(2, "b") // 1->a, 2->b (random order) - _, _ = m.GetKey("a") // 1, true - _, _ = m.Get(2) // b, true - _, _ = m.Get(3) // nil, false - _ = m.Values() // []interface {}{"a", "b"} (random order) - _ = m.Keys() // []interface {}{1, 2} (random order) - m.Remove(1) // 2->b - m.Clear() // empty - m.Empty() // true - m.Size() // 0 + m := hashbidimap.New[int, string]() // empty + m.Put(1, "x") // 1->x + m.Put(3, "b") // 1->x, 3->b (random order) + m.Put(1, "a") // 1->a, 3->b (random order) + m.Put(2, "b") // 1->a, 2->b (random order) + _, _ = m.GetKey("a") // 1, true + _, _ = m.Get(2) // b, true + _, _ = m.Get(3) // nil, false + _ = m.Values() // []interface {}{"a", "b"} (random order) + _ = m.Keys() // []interface {}{1, 2} (random order) + m.Remove(1) // 2->b + m.Clear() // empty + m.Empty() // true + m.Size() // 0 } diff --git a/examples/hashmap/hashmap.go b/examples/hashmap/hashmap.go index 2fda79e6..4be346d5 100644 --- a/examples/hashmap/hashmap.go +++ b/examples/hashmap/hashmap.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/maps/hashmap" +import "github.com/emirpasic/gods/v2/maps/hashmap" // HashMapExample to demonstrate basic usage of HashMap func main() { - m := hashmap.New() // empty - m.Put(1, "x") // 1->x - m.Put(2, "b") // 2->b, 1->x (random order) - m.Put(1, "a") // 2->b, 1->a (random order) - _, _ = m.Get(2) // b, true - _, _ = m.Get(3) // nil, false - _ = m.Values() // []interface {}{"b", "a"} (random order) - _ = m.Keys() // []interface {}{1, 2} (random order) - m.Remove(1) // 2->b - m.Clear() // empty - m.Empty() // true - m.Size() // 0 + m := hashmap.New[int, string]() // empty + m.Put(1, "x") // 1->x + m.Put(2, "b") // 2->b, 1->x (random order) + m.Put(1, "a") // 2->b, 1->a (random order) + _, _ = m.Get(2) // b, true + _, _ = m.Get(3) // nil, false + _ = m.Values() // []interface {}{"b", "a"} (random order) + _ = m.Keys() // []interface {}{1, 2} (random order) + m.Remove(1) // 2->b + m.Clear() // empty + m.Empty() // true + m.Size() // 0 } diff --git a/examples/hashset/hashset.go b/examples/hashset/hashset.go index 6c366e51..eccff89a 100644 --- a/examples/hashset/hashset.go +++ b/examples/hashset/hashset.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/sets/hashset" +import "github.com/emirpasic/gods/v2/sets/hashset" // HashSetExample to demonstrate basic usage of HashSet func main() { - set := hashset.New() // empty (keys are of type int) - set.Add(1) // 1 - set.Add(2, 2, 3, 4, 5) // 3, 1, 2, 4, 5 (random order, duplicates ignored) - set.Remove(4) // 5, 3, 2, 1 (random order) - set.Remove(2, 3) // 1, 5 (random order) - set.Contains(1) // true - set.Contains(1, 5) // true - set.Contains(1, 6) // false - _ = set.Values() // []int{5,1} (random order) - set.Clear() // empty - set.Empty() // true - set.Size() // 0 + set := hashset.New[int]() // empty (keys are of type int) + set.Add(1) // 1 + set.Add(2, 2, 3, 4, 5) // 3, 1, 2, 4, 5 (random order, duplicates ignored) + set.Remove(4) // 5, 3, 2, 1 (random order) + set.Remove(2, 3) // 1, 5 (random order) + set.Contains(1) // true + set.Contains(1, 5) // true + set.Contains(1, 6) // false + _ = set.Values() // []int{5,1} (random order) + set.Clear() // empty + set.Empty() // true + set.Size() // 0 } diff --git a/examples/iteratorwithindex/iteratorwithindex.go b/examples/iteratorwithindex/iteratorwithindex.go index 4cbc87e0..d68cf283 100644 --- a/examples/iteratorwithindex/iteratorwithindex.go +++ b/examples/iteratorwithindex/iteratorwithindex.go @@ -6,13 +6,14 @@ package main import ( "fmt" - "github.com/emirpasic/gods/sets/treeset" "strings" + + "github.com/emirpasic/gods/v2/sets/treeset" ) // IteratorWithIndexExample to demonstrate basic usage of IteratorWithIndex func main() { - set := treeset.NewWithStringComparator() + set := treeset.New[string]() set.Add("a", "b", "c") it := set.Iterator() @@ -51,8 +52,8 @@ func main() { } // Seek element starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } it.Begin() diff --git a/examples/iteratorwithkey/iteratorwithkey.go b/examples/iteratorwithkey/iteratorwithkey.go index 521e43b2..32396082 100644 --- a/examples/iteratorwithkey/iteratorwithkey.go +++ b/examples/iteratorwithkey/iteratorwithkey.go @@ -6,13 +6,14 @@ package main import ( "fmt" - "github.com/emirpasic/gods/maps/treemap" "strings" + + "github.com/emirpasic/gods/v2/maps/treemap" ) // IteratorWithKeyExample to demonstrate basic usage of IteratorWithKey func main() { - m := treemap.NewWithIntComparator() + m := treemap.New[int, string]() m.Put(0, "a") m.Put(1, "b") m.Put(2, "c") @@ -53,8 +54,8 @@ func main() { } // Seek key-value pair whose value starts with "b" - seek := func(key interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(key int, value string) bool { + return strings.HasSuffix(value, "b") } it.Begin() diff --git a/examples/linkedhashmap/linkedhashmap.go b/examples/linkedhashmap/linkedhashmap.go index 64434817..fe33f49d 100644 --- a/examples/linkedhashmap/linkedhashmap.go +++ b/examples/linkedhashmap/linkedhashmap.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/maps/linkedhashmap" +import "github.com/emirpasic/gods/v2/maps/linkedhashmap" // LinkedHashMapExample to demonstrate basic usage of LinkedHashMapExample func main() { - m := linkedhashmap.New() // empty (keys are of type int) - m.Put(2, "b") // 2->b - m.Put(1, "x") // 2->b, 1->x (insertion-order) - m.Put(1, "a") // 2->b, 1->a (insertion-order) - _, _ = m.Get(2) // b, true - _, _ = m.Get(3) // nil, false - _ = m.Values() // []interface {}{"b", "a"} (insertion-order) - _ = m.Keys() // []interface {}{2, 1} (insertion-order) - m.Remove(1) // 2->b - m.Clear() // empty - m.Empty() // true - m.Size() // 0 + m := linkedhashmap.New[int, string]() // empty (keys are of type int) + m.Put(2, "b") // 2->b + m.Put(1, "x") // 2->b, 1->x (insertion-order) + m.Put(1, "a") // 2->b, 1->a (insertion-order) + _, _ = m.Get(2) // b, true + _, _ = m.Get(3) // nil, false + _ = m.Values() // []interface {}{"b", "a"} (insertion-order) + _ = m.Keys() // []interface {}{2, 1} (insertion-order) + m.Remove(1) // 2->b + m.Clear() // empty + m.Empty() // true + m.Size() // 0 } diff --git a/examples/linkedhashset/linkedhashset.go b/examples/linkedhashset/linkedhashset.go index 689d212d..e6a5d69c 100644 --- a/examples/linkedhashset/linkedhashset.go +++ b/examples/linkedhashset/linkedhashset.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/sets/linkedhashset" +import "github.com/emirpasic/gods/v2/sets/linkedhashset" // LinkedHashSetExample to demonstrate basic usage of LinkedHashSet func main() { - set := linkedhashset.New() // empty - set.Add(5) // 5 - set.Add(4, 4, 3, 2, 1) // 5, 4, 3, 2, 1 (in insertion-order, duplicates ignored) - set.Remove(4) // 5, 3, 2, 1 (in insertion-order) - set.Remove(2, 3) // 5, 1 (in insertion-order) - set.Contains(1) // true - set.Contains(1, 5) // true - set.Contains(1, 6) // false - _ = set.Values() // []int{5, 1} (in insertion-order) - set.Clear() // empty - set.Empty() // true - set.Size() // 0 + set := linkedhashset.New[int]() // empty + set.Add(5) // 5 + set.Add(4, 4, 3, 2, 1) // 5, 4, 3, 2, 1 (in insertion-order, duplicates ignored) + set.Remove(4) // 5, 3, 2, 1 (in insertion-order) + set.Remove(2, 3) // 5, 1 (in insertion-order) + set.Contains(1) // true + set.Contains(1, 5) // true + set.Contains(1, 6) // false + _ = set.Values() // []int{5, 1} (in insertion-order) + set.Clear() // empty + set.Empty() // true + set.Size() // 0 } diff --git a/examples/linkedlistqueue/linkedlistqueue.go b/examples/linkedlistqueue/linkedlistqueue.go index d6800b5f..2a61d2f5 100644 --- a/examples/linkedlistqueue/linkedlistqueue.go +++ b/examples/linkedlistqueue/linkedlistqueue.go @@ -4,20 +4,20 @@ package main -import llq "github.com/emirpasic/gods/queues/linkedlistqueue" +import llq "github.com/emirpasic/gods/v2/queues/linkedlistqueue" // LinkedListQueueExample to demonstrate basic usage of LinkedListQueue func main() { - queue := llq.New() // empty - queue.Enqueue(1) // 1 - queue.Enqueue(2) // 1, 2 - _ = queue.Values() // 1, 2 (FIFO order) - _, _ = queue.Peek() // 1,true - _, _ = queue.Dequeue() // 1, true - _, _ = queue.Dequeue() // 2, true - _, _ = queue.Dequeue() // nil, false (nothing to deque) - queue.Enqueue(1) // 1 - queue.Clear() // empty - queue.Empty() // true - _ = queue.Size() // 0 + queue := llq.New[int]() // empty + queue.Enqueue(1) // 1 + queue.Enqueue(2) // 1, 2 + _ = queue.Values() // 1, 2 (FIFO order) + _, _ = queue.Peek() // 1,true + _, _ = queue.Dequeue() // 1, true + _, _ = queue.Dequeue() // 2, true + _, _ = queue.Dequeue() // nil, false (nothing to deque) + queue.Enqueue(1) // 1 + queue.Clear() // empty + queue.Empty() // true + _ = queue.Size() // 0 } diff --git a/examples/linkedliststack/linkedliststack.go b/examples/linkedliststack/linkedliststack.go index e9f1a68e..81ace6c6 100644 --- a/examples/linkedliststack/linkedliststack.go +++ b/examples/linkedliststack/linkedliststack.go @@ -4,20 +4,20 @@ package main -import lls "github.com/emirpasic/gods/stacks/linkedliststack" +import lls "github.com/emirpasic/gods/v2/stacks/linkedliststack" // LinkedListStackExample to demonstrate basic usage of LinkedListStack func main() { - stack := lls.New() // empty - stack.Push(1) // 1 - stack.Push(2) // 1, 2 - stack.Values() // 2, 1 (LIFO order) - _, _ = stack.Peek() // 2,true - _, _ = stack.Pop() // 2, true - _, _ = stack.Pop() // 1, true - _, _ = stack.Pop() // nil, false (nothing to pop) - stack.Push(1) // 1 - stack.Clear() // empty - stack.Empty() // true - stack.Size() // 0 + stack := lls.New[int]() // empty + stack.Push(1) // 1 + stack.Push(2) // 1, 2 + stack.Values() // 2, 1 (LIFO order) + _, _ = stack.Peek() // 2,true + _, _ = stack.Pop() // 2, true + _, _ = stack.Pop() // 1, true + _, _ = stack.Pop() // nil, false (nothing to pop) + stack.Push(1) // 1 + stack.Clear() // empty + stack.Empty() // true + stack.Size() // 0 } diff --git a/examples/priorityqueue/priorityqueue.go b/examples/priorityqueue/priorityqueue.go index 11fd1e57..e77f2210 100644 --- a/examples/priorityqueue/priorityqueue.go +++ b/examples/priorityqueue/priorityqueue.go @@ -5,8 +5,9 @@ package main import ( - pq "github.com/emirpasic/gods/queues/priorityqueue" - "github.com/emirpasic/gods/utils" + "cmp" + + pq "github.com/emirpasic/gods/v2/queues/priorityqueue" ) // Element is an entry in the priority queue @@ -16,10 +17,8 @@ type Element struct { } // Comparator function (sort by element's priority value in descending order) -func byPriority(a, b interface{}) int { - priorityA := a.(Element).priority - priorityB := b.(Element).priority - return -utils.IntComparator(priorityA, priorityB) // "-" descending order +func byPriority(a, b Element) int { + return -cmp.Compare(a.priority, b.priority) // "-" descending order } // PriorityQueueExample to demonstrate basic usage of BinaryHeap diff --git a/examples/redblacktree/redblacktree.go b/examples/redblacktree/redblacktree.go index b7d9803f..6661b166 100644 --- a/examples/redblacktree/redblacktree.go +++ b/examples/redblacktree/redblacktree.go @@ -6,12 +6,13 @@ package main import ( "fmt" - rbt "github.com/emirpasic/gods/trees/redblacktree" + + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // RedBlackTreeExample to demonstrate basic usage of RedBlackTree func main() { - tree := rbt.NewWithIntComparator() // empty(keys are of type int) + tree := rbt.New[int, string]() // empty(keys are of type int) tree.Put(1, "x") // 1->x tree.Put(2, "b") // 1->x, 2->b (in order) diff --git a/examples/redblacktreeextended/redblacktreeextended.go b/examples/redblacktreeextended/redblacktreeextended.go index 6e901299..602dfecc 100644 --- a/examples/redblacktreeextended/redblacktreeextended.go +++ b/examples/redblacktreeextended/redblacktreeextended.go @@ -6,53 +6,54 @@ package redblacktreeextended import ( "fmt" - rbt "github.com/emirpasic/gods/trees/redblacktree" + + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // RedBlackTreeExtended to demonstrate how to extend a RedBlackTree to include new functions -type RedBlackTreeExtended struct { - *rbt.Tree +type RedBlackTreeExtended[K comparable, V any] struct { + *rbt.Tree[K, V] } // GetMin gets the min value and flag if found -func (tree *RedBlackTreeExtended) GetMin() (value interface{}, found bool) { +func (tree *RedBlackTreeExtended[K, V]) GetMin() (value V, found bool) { node, found := tree.getMinFromNode(tree.Root) - if node != nil { + if found { return node.Value, found } - return nil, false + return value, false } // GetMax gets the max value and flag if found -func (tree *RedBlackTreeExtended) GetMax() (value interface{}, found bool) { +func (tree *RedBlackTreeExtended[K, V]) GetMax() (value V, found bool) { node, found := tree.getMaxFromNode(tree.Root) - if node != nil { + if found { return node.Value, found } - return nil, false + return value, false } // RemoveMin removes the min value and flag if found -func (tree *RedBlackTreeExtended) RemoveMin() (value interface{}, deleted bool) { +func (tree *RedBlackTreeExtended[K, V]) RemoveMin() (value V, deleted bool) { node, found := tree.getMinFromNode(tree.Root) if found { tree.Remove(node.Key) return node.Value, found } - return nil, false + return value, false } // RemoveMax removes the max value and flag if found -func (tree *RedBlackTreeExtended) RemoveMax() (value interface{}, deleted bool) { +func (tree *RedBlackTreeExtended[K, V]) RemoveMax() (value V, deleted bool) { node, found := tree.getMaxFromNode(tree.Root) if found { tree.Remove(node.Key) return node.Value, found } - return nil, false + return value, false } -func (tree *RedBlackTreeExtended) getMinFromNode(node *rbt.Node) (foundNode *rbt.Node, found bool) { +func (tree *RedBlackTreeExtended[K, V]) getMinFromNode(node *rbt.Node[K, V]) (foundNode *rbt.Node[K, V], found bool) { if node == nil { return nil, false } @@ -62,7 +63,7 @@ func (tree *RedBlackTreeExtended) getMinFromNode(node *rbt.Node) (foundNode *rbt return tree.getMinFromNode(node.Left) } -func (tree *RedBlackTreeExtended) getMaxFromNode(node *rbt.Node) (foundNode *rbt.Node, found bool) { +func (tree *RedBlackTreeExtended[K, V]) getMaxFromNode(node *rbt.Node[K, V]) (foundNode *rbt.Node[K, V], found bool) { if node == nil { return nil, false } @@ -72,7 +73,7 @@ func (tree *RedBlackTreeExtended) getMaxFromNode(node *rbt.Node) (foundNode *rbt return tree.getMaxFromNode(node.Right) } -func print(tree *RedBlackTreeExtended) { +func print(tree *RedBlackTreeExtended[int, string]) { max, _ := tree.GetMax() min, _ := tree.GetMin() fmt.Printf("Value for max key: %v \n", max) @@ -82,7 +83,7 @@ func print(tree *RedBlackTreeExtended) { // RedBlackTreeExtendedExample main method on how to use the custom red-black tree above func main() { - tree := RedBlackTreeExtended{rbt.NewWithIntComparator()} + tree := RedBlackTreeExtended[int, string]{rbt.New[int, string]()} tree.Put(1, "a") // 1->x (in order) tree.Put(2, "b") // 1->x, 2->b (in order) diff --git a/examples/serialization/serialization.go b/examples/serialization/serialization.go index 2f94c5ed..c03252ce 100644 --- a/examples/serialization/serialization.go +++ b/examples/serialization/serialization.go @@ -2,13 +2,14 @@ package serialization import ( "fmt" - "github.com/emirpasic/gods/lists/arraylist" - "github.com/emirpasic/gods/maps/hashmap" + + "github.com/emirpasic/gods/v2/lists/arraylist" + "github.com/emirpasic/gods/v2/maps/hashmap" ) // ListSerializationExample demonstrates how to serialize and deserialize lists to and from JSON func ListSerializationExample() { - list := arraylist.New() + list := arraylist.New[string]() list.Add("a", "b", "c") // Serialization (marshalling) @@ -29,7 +30,7 @@ func ListSerializationExample() { // MapSerializationExample demonstrates how to serialize and deserialize maps to and from JSON func MapSerializationExample() { - m := hashmap.New() + m := hashmap.New[string, string]() m.Put("a", "1") m.Put("b", "2") m.Put("c", "3") diff --git a/examples/singlylinkedlist/singlylinkedlist.go b/examples/singlylinkedlist/singlylinkedlist.go index 93b4ccaf..3d327848 100644 --- a/examples/singlylinkedlist/singlylinkedlist.go +++ b/examples/singlylinkedlist/singlylinkedlist.go @@ -5,17 +5,18 @@ package main import ( - sll "github.com/emirpasic/gods/lists/singlylinkedlist" - "github.com/emirpasic/gods/utils" + "cmp" + + sll "github.com/emirpasic/gods/v2/lists/singlylinkedlist" ) // SinglyLinkedListExample to demonstrate basic usage of SinglyLinkedList func main() { - list := sll.New() + list := sll.New[string]() list.Add("a") // ["a"] list.Append("b") // ["a","b"] (same as Add()) list.Prepend("c") // ["c","a","b"] - list.Sort(utils.StringComparator) // ["a","b","c"] + list.Sort(cmp.Compare[string]) // ["a","b","c"] _, _ = list.Get(0) // "a",true _, _ = list.Get(100) // nil,false _ = list.Contains("a", "b", "c") // true diff --git a/examples/treebidimap/treebidimap.go b/examples/treebidimap/treebidimap.go index 0c63f122..f378eca2 100644 --- a/examples/treebidimap/treebidimap.go +++ b/examples/treebidimap/treebidimap.go @@ -5,13 +5,12 @@ package main import ( - "github.com/emirpasic/gods/maps/treebidimap" - "github.com/emirpasic/gods/utils" + "github.com/emirpasic/gods/v2/maps/treebidimap" ) // TreeBidiMapExample to demonstrate basic usage of TreeBidiMap func main() { - m := treebidimap.NewWith(utils.IntComparator, utils.StringComparator) + m := treebidimap.New[int, string]() m.Put(1, "x") // 1->x m.Put(3, "b") // 1->x, 3->b (ordered) m.Put(1, "a") // 1->a, 3->b (ordered) diff --git a/examples/treemap/treemap.go b/examples/treemap/treemap.go index 66b62cc4..4a1a21b2 100644 --- a/examples/treemap/treemap.go +++ b/examples/treemap/treemap.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/maps/treemap" +import "github.com/emirpasic/gods/v2/maps/treemap" // TreeMapExample to demonstrate basic usage of TreeMap func main() { - m := treemap.NewWithIntComparator() // empty (keys are of type int) - m.Put(1, "x") // 1->x - m.Put(2, "b") // 1->x, 2->b (in order) - m.Put(1, "a") // 1->a, 2->b (in order) - _, _ = m.Get(2) // b, true - _, _ = m.Get(3) // nil, false - _ = m.Values() // []interface {}{"a", "b"} (in order) - _ = m.Keys() // []interface {}{1, 2} (in order) - m.Remove(1) // 2->b - m.Clear() // empty - m.Empty() // true - m.Size() // 0 + m := treemap.New[int, string]() // empty + m.Put(1, "x") // 1->x + m.Put(2, "b") // 1->x, 2->b (in order) + m.Put(1, "a") // 1->a, 2->b (in order) + _, _ = m.Get(2) // b, true + _, _ = m.Get(3) // nil, false + _ = m.Values() // []interface {}{"a", "b"} (in order) + _ = m.Keys() // []interface {}{1, 2} (in order) + m.Remove(1) // 2->b + m.Clear() // empty + m.Empty() // true + m.Size() // 0 } diff --git a/examples/treeset/treeset.go b/examples/treeset/treeset.go index 15a7f81a..e56b148e 100644 --- a/examples/treeset/treeset.go +++ b/examples/treeset/treeset.go @@ -4,20 +4,20 @@ package main -import "github.com/emirpasic/gods/sets/treeset" +import "github.com/emirpasic/gods/v2/sets/treeset" // TreeSetExample to demonstrate basic usage of TreeSet func main() { - set := treeset.NewWithIntComparator() // empty - set.Add(1) // 1 - set.Add(2, 2, 3, 4, 5) // 1, 2, 3, 4, 5 (in order, duplicates ignored) - set.Remove(4) // 1, 2, 3, 5 (in order) - set.Remove(2, 3) // 1, 5 (in order) - set.Contains(1) // true - set.Contains(1, 5) // true - set.Contains(1, 6) // false - _ = set.Values() // []int{1,5} (in order) - set.Clear() // empty - set.Empty() // true - set.Size() // 0 + set := treeset.New[int]() // empty + set.Add(1) // 1 + set.Add(2, 2, 3, 4, 5) // 1, 2, 3, 4, 5 (in order, duplicates ignored) + set.Remove(4) // 1, 2, 3, 5 (in order) + set.Remove(2, 3) // 1, 5 (in order) + set.Contains(1) // true + set.Contains(1, 5) // true + set.Contains(1, 6) // false + _ = set.Values() // []int{1,5} (in order) + set.Clear() // empty + set.Empty() // true + set.Size() // 0 } diff --git a/go.mod b/go.mod index 5160ad04..4d864ff5 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/emirpasic/gods +module github.com/emirpasic/gods/v2 -go 1.2 +go 1.21 diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..b5ad6665 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= diff --git a/lists/arraylist/arraylist.go b/lists/arraylist/arraylist.go index 60ce4583..e67ff292 100644 --- a/lists/arraylist/arraylist.go +++ b/lists/arraylist/arraylist.go @@ -11,18 +11,19 @@ package arraylist import ( "fmt" + "slices" "strings" - "github.com/emirpasic/gods/lists" - "github.com/emirpasic/gods/utils" + "github.com/emirpasic/gods/v2/lists" + "github.com/emirpasic/gods/v2/utils" ) // Assert List implementation -var _ lists.List = (*List)(nil) +var _ lists.List[int] = (*List[int])(nil) // List holds the elements in a slice -type List struct { - elements []interface{} +type List[T comparable] struct { + elements []T size int } @@ -32,8 +33,8 @@ const ( ) // New instantiates a new list and adds the passed values, if any, to the list -func New(values ...interface{}) *List { - list := &List{} +func New[T comparable](values ...T) *List[T] { + list := &List[T]{} if len(values) > 0 { list.Add(values...) } @@ -41,7 +42,7 @@ func New(values ...interface{}) *List { } // Add appends a value at the end of the list -func (list *List) Add(values ...interface{}) { +func (list *List[T]) Add(values ...T) { list.growBy(len(values)) for _, value := range values { list.elements[list.size] = value @@ -51,23 +52,24 @@ func (list *List) Add(values ...interface{}) { // Get returns the element at index. // Second return parameter is true if index is within bounds of the array and array is not empty, otherwise false. -func (list *List) Get(index int) (interface{}, bool) { +func (list *List[T]) Get(index int) (T, bool) { if !list.withinRange(index) { - return nil, false + var t T + return t, false } return list.elements[index], true } // Remove removes the element at the given index from the list. -func (list *List) Remove(index int) { +func (list *List[T]) Remove(index int) { if !list.withinRange(index) { return } - list.elements[index] = nil // cleanup reference + clear(list.elements[index : index+1]) copy(list.elements[index:], list.elements[index+1:list.size]) // shift to the left by one (slow operation, need ways to optimize this) list.size-- @@ -78,7 +80,7 @@ func (list *List) Remove(index int) { // All elements have to be present in the set for the method to return true. // Performance time complexity of n^2. // Returns true if no arguments are passed at all, i.e. set is always super-set of empty set. -func (list *List) Contains(values ...interface{}) bool { +func (list *List[T]) Contains(values ...T) bool { for _, searchValue := range values { found := false @@ -96,14 +98,14 @@ func (list *List) Contains(values ...interface{}) bool { } // Values returns all elements in the list. -func (list *List) Values() []interface{} { - newElements := make([]interface{}, list.size, list.size) +func (list *List[T]) Values() []T { + newElements := make([]T, list.size, list.size) copy(newElements, list.elements[:list.size]) return newElements } -//IndexOf returns index of provided element -func (list *List) IndexOf(value interface{}) int { +// IndexOf returns index of provided element +func (list *List[T]) IndexOf(value T) int { if list.size == 0 { return -1 } @@ -116,31 +118,31 @@ func (list *List) IndexOf(value interface{}) int { } // Empty returns true if list does not contain any elements. -func (list *List) Empty() bool { +func (list *List[T]) Empty() bool { return list.size == 0 } // Size returns number of elements within the list. -func (list *List) Size() int { +func (list *List[T]) Size() int { return list.size } // Clear removes all elements from the list. -func (list *List) Clear() { +func (list *List[T]) Clear() { list.size = 0 - list.elements = []interface{}{} + list.elements = []T{} } // Sort sorts values (in-place) using. -func (list *List) Sort(comparator utils.Comparator) { +func (list *List[T]) Sort(comparator utils.Comparator[T]) { if len(list.elements) < 2 { return } - utils.Sort(list.elements[:list.size], comparator) + slices.SortFunc(list.elements[:list.size], comparator) } // Swap swaps the two values at the specified positions. -func (list *List) Swap(i, j int) { +func (list *List[T]) Swap(i, j int) { if list.withinRange(i) && list.withinRange(j) { list.elements[i], list.elements[j] = list.elements[j], list.elements[i] } @@ -149,7 +151,7 @@ func (list *List) Swap(i, j int) { // Insert inserts values at specified index position shifting the value at that position (if any) and any subsequent elements to the right. // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Insert(index int, values ...interface{}) { +func (list *List[T]) Insert(index int, values ...T) { if !list.withinRange(index) { // Append @@ -169,7 +171,7 @@ func (list *List) Insert(index int, values ...interface{}) { // Set the value at specified index // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Set(index int, value interface{}) { +func (list *List[T]) Set(index int, value T) { if !list.withinRange(index) { // Append @@ -183,9 +185,9 @@ func (list *List) Set(index int, value interface{}) { } // String returns a string representation of container -func (list *List) String() string { +func (list *List[T]) String() string { str := "ArrayList\n" - values := []string{} + values := make([]string, 0, list.size) for _, value := range list.elements[:list.size] { values = append(values, fmt.Sprintf("%v", value)) } @@ -194,18 +196,18 @@ func (list *List) String() string { } // Check that the index is within bounds of the list -func (list *List) withinRange(index int) bool { +func (list *List[T]) withinRange(index int) bool { return index >= 0 && index < list.size } -func (list *List) resize(cap int) { - newElements := make([]interface{}, cap, cap) +func (list *List[T]) resize(cap int) { + newElements := make([]T, cap, cap) copy(newElements, list.elements) list.elements = newElements } // Expand the array if necessary, i.e. capacity will be reached if we add n elements -func (list *List) growBy(n int) { +func (list *List[T]) growBy(n int) { // When capacity is reached, grow by a factor of growthFactor and add number of elements currentCapacity := cap(list.elements) if list.size+n >= currentCapacity { @@ -215,7 +217,7 @@ func (list *List) growBy(n int) { } // Shrink the array if necessary, i.e. when size is shrinkFactor percent of current capacity -func (list *List) shrink() { +func (list *List[T]) shrink() { if shrinkFactor == 0.0 { return } diff --git a/lists/arraylist/arraylist_test.go b/lists/arraylist/arraylist_test.go index 3b7c8d77..614b436d 100644 --- a/lists/arraylist/arraylist_test.go +++ b/lists/arraylist/arraylist_test.go @@ -5,21 +5,21 @@ package arraylist import ( + "cmp" "encoding/json" - "fmt" - "github.com/emirpasic/gods/utils" + "slices" "strings" "testing" ) func TestListNew(t *testing.T) { - list1 := New() + list1 := New[int]() if actualValue := list1.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } - list2 := New(1, "b") + list2 := New[int](1, 2) if actualValue := list2.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) @@ -29,17 +29,17 @@ func TestListNew(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, 1) } - if actualValue, ok := list2.Get(1); actualValue != "b" || !ok { - t.Errorf("Got %v expected %v", actualValue, "b") + if actualValue, ok := list2.Get(1); actualValue != 2 || !ok { + t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, ok := list2.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list2.Get(2); actualValue != 0 || ok { + t.Errorf("Got %v expected %v", actualValue, 0) } } func TestListAdd(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Empty(); actualValue != false { @@ -54,7 +54,7 @@ func TestListAdd(t *testing.T) { } func TestListIndexOf(t *testing.T) { - list := New() + list := New[string]() expectedIndex := -1 if index := list.IndexOf("a"); index != expectedIndex { @@ -81,12 +81,12 @@ func TestListIndexOf(t *testing.T) { } func TestListRemove(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Remove(2) - if actualValue, ok := list.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(2); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(1) list.Remove(0) @@ -100,7 +100,7 @@ func TestListRemove(t *testing.T) { } func TestListGet(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue, ok := list.Get(0); actualValue != "a" || !ok { @@ -112,8 +112,8 @@ func TestListGet(t *testing.T) { if actualValue, ok := list.Get(2); actualValue != "c" || !ok { t.Errorf("Got %v expected %v", actualValue, "c") } - if actualValue, ok := list.Get(3); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(3); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(0) if actualValue, ok := list.Get(0); actualValue != "b" || !ok { @@ -122,7 +122,7 @@ func TestListGet(t *testing.T) { } func TestListSwap(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Swap(0, 1) @@ -132,21 +132,21 @@ func TestListSwap(t *testing.T) { } func TestListSort(t *testing.T) { - list := New() - list.Sort(utils.StringComparator) + list := New[string]() + list.Sort(cmp.Compare[string]) list.Add("e", "f", "g", "a", "b", "c", "d") - list.Sort(utils.StringComparator) + list.Sort(cmp.Compare[string]) for i := 1; i < list.Size(); i++ { a, _ := list.Get(i - 1) b, _ := list.Get(i) - if a.(string) > b.(string) { + if a > b { t.Errorf("Not sorted! %s > %s", a, b) } } } func TestListClear(t *testing.T) { - list := New() + list := New[string]() list.Add("e", "f", "g", "a", "b", "c", "d") list.Clear() if actualValue := list.Empty(); actualValue != true { @@ -158,13 +158,13 @@ func TestListClear(t *testing.T) { } func TestListContains(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Contains("a"); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } - if actualValue := list.Contains(nil); actualValue != false { + if actualValue := list.Contains(""); actualValue != false { t.Errorf("Got %v expected %v", actualValue, false) } if actualValue := list.Contains("a", "b", "c"); actualValue != true { @@ -183,16 +183,16 @@ func TestListContains(t *testing.T) { } func TestListValues(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListInsert(t *testing.T) { - list := New() + list := New[string]() list.Insert(0, "b", "c") list.Insert(0, "a") list.Insert(10, "x") // ignore @@ -203,13 +203,13 @@ func TestListInsert(t *testing.T) { if actualValue := list.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", list.Values()...), "abcd"; actualValue != expectedValue { + if actualValue, expectedValue := strings.Join(list.Values(), ""), "abcd"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListSet(t *testing.T) { - list := New() + list := New[string]() list.Set(0, "a") list.Set(1, "b") if actualValue := list.Size(); actualValue != 2 { @@ -224,15 +224,15 @@ func TestListSet(t *testing.T) { if actualValue := list.Size(); actualValue != 3 { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abbc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "bb", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListEach(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - list.Each(func(index int, value interface{}) { + list.Each(func(index int, value string) { switch index { case 0: if actualValue, expectedValue := value, "a"; actualValue != expectedValue { @@ -253,10 +253,10 @@ func TestListEach(t *testing.T) { } func TestListMap(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - mappedList := list.Map(func(index int, value interface{}) interface{} { - return "mapped: " + value.(string) + mappedList := list.Map(func(index int, value string) string { + return "mapped: " + value }) if actualValue, _ := mappedList.Get(0); actualValue != "mapped: a" { t.Errorf("Got %v expected %v", actualValue, "mapped: a") @@ -273,10 +273,10 @@ func TestListMap(t *testing.T) { } func TestListSelect(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - selectedList := list.Select(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + selectedList := list.Select(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if actualValue, _ := selectedList.Get(0); actualValue != "a" { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -290,60 +290,60 @@ func TestListSelect(t *testing.T) { } func TestListAny(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - any := list.Any(func(index int, value interface{}) bool { - return value.(string) == "c" + any := list.Any(func(index int, value string) bool { + return value == "c" }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = list.Any(func(index int, value interface{}) bool { - return value.(string) == "x" + any = list.Any(func(index int, value string) bool { + return value == "x" }) if any != false { t.Errorf("Got %v expected %v", any, false) } } func TestListAll(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - all := list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "c" + all := list.All(func(index int, value string) bool { + return value >= "a" && value <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + all = list.All(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) } } func TestListFind(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - foundIndex, foundValue := list.Find(func(index int, value interface{}) bool { - return value.(string) == "c" + foundIndex, foundValue := list.Find(func(index int, value string) bool { + return value == "c" }) if foundValue != "c" || foundIndex != 2 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 2) } - foundIndex, foundValue = list.Find(func(index int, value interface{}) bool { - return value.(string) == "x" + foundIndex, foundValue = list.Find(func(index int, value string) bool { + return value == "x" }) - if foundValue != nil || foundIndex != -1 { + if foundValue != "" || foundIndex != -1 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) } } func TestListChaining(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - chainedList := list.Select(func(index int, value interface{}) bool { - return value.(string) > "a" - }).Map(func(index int, value interface{}) interface{} { - return value.(string) + value.(string) + chainedList := list.Select(func(index int, value string) bool { + return value > "a" + }).Map(func(index int, value string) string { + return value + value }) if chainedList.Size() != 2 { t.Errorf("Got %v expected %v", chainedList.Size(), 2) @@ -357,7 +357,7 @@ func TestListChaining(t *testing.T) { } func TestListIteratorNextOnEmpty(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty list") @@ -365,7 +365,7 @@ func TestListIteratorNextOnEmpty(t *testing.T) { } func TestListIteratorNext(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") it := list.Iterator() count := 0 @@ -396,7 +396,7 @@ func TestListIteratorNext(t *testing.T) { } func TestListIteratorPrevOnEmpty(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty list") @@ -404,7 +404,7 @@ func TestListIteratorPrevOnEmpty(t *testing.T) { } func TestListIteratorPrev(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") it := list.Iterator() for it.Next() { @@ -437,7 +437,7 @@ func TestListIteratorPrev(t *testing.T) { } func TestListIteratorBegin(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() it.Begin() list.Add("a", "b", "c") @@ -451,7 +451,7 @@ func TestListIteratorBegin(t *testing.T) { } func TestListIteratorEnd(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if index := it.Index(); index != -1 { @@ -476,7 +476,7 @@ func TestListIteratorEnd(t *testing.T) { } func TestListIteratorFirst(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -491,7 +491,7 @@ func TestListIteratorFirst(t *testing.T) { } func TestListIteratorLast(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -507,13 +507,13 @@ func TestListIteratorLast(t *testing.T) { func TestListIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - list := New() + list := New[string]() it := list.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") @@ -522,7 +522,7 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (not found) { - list := New() + list := New[string]() list.Add("xx", "yy") it := list.Iterator() for it.NextTo(seek) { @@ -532,20 +532,20 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (found) { - list := New() + list := New[string]() list.Add("aa", "bb", "cc") it := list.Iterator() it.Begin() if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -556,13 +556,13 @@ func TestListIteratorNextTo(t *testing.T) { func TestListIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - list := New() + list := New[string]() it := list.Iterator() it.End() for it.PrevTo(seek) { @@ -572,7 +572,7 @@ func TestListIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - list := New() + list := New[string]() list.Add("xx", "yy") it := list.Iterator() it.End() @@ -583,20 +583,20 @@ func TestListIteratorPrevTo(t *testing.T) { // PrevTo (found) { - list := New() + list := New[string]() list.Add("aa", "bb", "cc") it := list.Iterator() it.End() if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -606,12 +606,12 @@ func TestListIteratorPrevTo(t *testing.T) { } func TestListSerialization(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue, expectedValue := list.Size(), 3; actualValue != expectedValue { @@ -630,26 +630,27 @@ func TestListSerialization(t *testing.T) { err = list.FromJSON(bytes) assert() - bytes, err = json.Marshal([]interface{}{"a", "b", "c", list}) + bytes, err = json.Marshal([]any{"a", "b", "c", list}) if err != nil { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &list) + err = json.Unmarshal([]byte(`["a","b","c"]`), &list) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestListString(t *testing.T) { - c := New() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "ArrayList") { t.Errorf("String should start with container name") } } -func benchmarkGet(b *testing.B, list *List, size int) { +func benchmarkGet(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Get(n) @@ -657,7 +658,7 @@ func benchmarkGet(b *testing.B, list *List, size int) { } } -func benchmarkAdd(b *testing.B, list *List, size int) { +func benchmarkAdd(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Add(n) @@ -665,7 +666,7 @@ func benchmarkAdd(b *testing.B, list *List, size int) { } } -func benchmarkRemove(b *testing.B, list *List, size int) { +func benchmarkRemove(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Remove(n) @@ -676,7 +677,7 @@ func benchmarkRemove(b *testing.B, list *List, size int) { func BenchmarkArrayListGet100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -687,7 +688,7 @@ func BenchmarkArrayListGet100(b *testing.B) { func BenchmarkArrayListGet1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -698,7 +699,7 @@ func BenchmarkArrayListGet1000(b *testing.B) { func BenchmarkArrayListGet10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -709,7 +710,7 @@ func BenchmarkArrayListGet10000(b *testing.B) { func BenchmarkArrayListGet100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -720,7 +721,7 @@ func BenchmarkArrayListGet100000(b *testing.B) { func BenchmarkArrayListAdd100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() b.StartTimer() benchmarkAdd(b, list, size) } @@ -728,7 +729,7 @@ func BenchmarkArrayListAdd100(b *testing.B) { func BenchmarkArrayListAdd1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -739,7 +740,7 @@ func BenchmarkArrayListAdd1000(b *testing.B) { func BenchmarkArrayListAdd10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -750,7 +751,7 @@ func BenchmarkArrayListAdd10000(b *testing.B) { func BenchmarkArrayListAdd100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -761,7 +762,7 @@ func BenchmarkArrayListAdd100000(b *testing.B) { func BenchmarkArrayListRemove100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -772,7 +773,7 @@ func BenchmarkArrayListRemove100(b *testing.B) { func BenchmarkArrayListRemove1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -783,7 +784,7 @@ func BenchmarkArrayListRemove1000(b *testing.B) { func BenchmarkArrayListRemove10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -794,7 +795,7 @@ func BenchmarkArrayListRemove10000(b *testing.B) { func BenchmarkArrayListRemove100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } diff --git a/lists/arraylist/enumerable.go b/lists/arraylist/enumerable.go index 8bd60b0a..94589d72 100644 --- a/lists/arraylist/enumerable.go +++ b/lists/arraylist/enumerable.go @@ -4,13 +4,13 @@ package arraylist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithIndex = (*List)(nil) +var _ containers.EnumerableWithIndex[int] = (*List[int])(nil) // Each calls the given function once for each element, passing that element's index and value. -func (list *List) Each(f func(index int, value interface{})) { +func (list *List[T]) Each(f func(index int, value T)) { iterator := list.Iterator() for iterator.Next() { f(iterator.Index(), iterator.Value()) @@ -19,8 +19,8 @@ func (list *List) Each(f func(index int, value interface{})) { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. -func (list *List) Map(f func(index int, value interface{}) interface{}) *List { - newList := &List{} +func (list *List[T]) Map(f func(index int, value T) T) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { newList.Add(f(iterator.Index(), iterator.Value())) @@ -29,8 +29,8 @@ func (list *List) Map(f func(index int, value interface{}) interface{}) *List { } // Select returns a new container containing all elements for which the given function returns a true value. -func (list *List) Select(f func(index int, value interface{}) bool) *List { - newList := &List{} +func (list *List[T]) Select(f func(index int, value T) bool) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -42,7 +42,7 @@ func (list *List) Select(f func(index int, value interface{}) bool) *List { // Any passes each element of the collection to the given function and // returns true if the function ever returns true for any element. -func (list *List) Any(f func(index int, value interface{}) bool) bool { +func (list *List[T]) Any(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -54,7 +54,7 @@ func (list *List) Any(f func(index int, value interface{}) bool) bool { // All passes each element of the collection to the given function and // returns true if the function returns true for all elements. -func (list *List) All(f func(index int, value interface{}) bool) bool { +func (list *List[T]) All(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if !f(iterator.Index(), iterator.Value()) { @@ -67,12 +67,13 @@ func (list *List) All(f func(index int, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. -func (list *List) Find(f func(index int, value interface{}) bool) (int, interface{}) { +func (list *List[T]) Find(f func(index int, value T) bool) (int, T) { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { return iterator.Index(), iterator.Value() } } - return -1, nil + var t T + return -1, t } diff --git a/lists/arraylist/iterator.go b/lists/arraylist/iterator.go index f9efe20c..5a6705ce 100644 --- a/lists/arraylist/iterator.go +++ b/lists/arraylist/iterator.go @@ -4,27 +4,27 @@ package arraylist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator holding the iterator's state -type Iterator struct { - list *List +type Iterator[T comparable] struct { + list *List[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (list *List) Iterator() Iterator { - return Iterator{list: list, index: -1} +func (list *List[T]) Iterator() *Iterator[T] { + return &Iterator[T]{list: list, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.list.size { iterator.index++ } @@ -34,7 +34,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -43,32 +43,32 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.list.elements[iterator.index] } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.list.size } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -76,7 +76,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -85,7 +85,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -99,7 +99,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/lists/arraylist/serialization.go b/lists/arraylist/serialization.go index 5e86fe96..a3aec534 100644 --- a/lists/arraylist/serialization.go +++ b/lists/arraylist/serialization.go @@ -6,20 +6,21 @@ package arraylist import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*List)(nil) -var _ containers.JSONDeserializer = (*List)(nil) +var _ containers.JSONSerializer = (*List[int])(nil) +var _ containers.JSONDeserializer = (*List[int])(nil) // ToJSON outputs the JSON representation of list's elements. -func (list *List) ToJSON() ([]byte, error) { +func (list *List[T]) ToJSON() ([]byte, error) { return json.Marshal(list.elements[:list.size]) } // FromJSON populates list's elements from the input JSON representation. -func (list *List) FromJSON(data []byte) error { +func (list *List[T]) FromJSON(data []byte) error { err := json.Unmarshal(data, &list.elements) if err == nil { list.size = len(list.elements) @@ -28,11 +29,11 @@ func (list *List) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (list *List) UnmarshalJSON(bytes []byte) error { +func (list *List[T]) UnmarshalJSON(bytes []byte) error { return list.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (list *List) MarshalJSON() ([]byte, error) { +func (list *List[T]) MarshalJSON() ([]byte, error) { return list.ToJSON() } diff --git a/lists/doublylinkedlist/doublylinkedlist.go b/lists/doublylinkedlist/doublylinkedlist.go index d0e2b3a6..6a34e9de 100644 --- a/lists/doublylinkedlist/doublylinkedlist.go +++ b/lists/doublylinkedlist/doublylinkedlist.go @@ -11,31 +11,32 @@ package doublylinkedlist import ( "fmt" + "slices" "strings" - "github.com/emirpasic/gods/lists" - "github.com/emirpasic/gods/utils" + "github.com/emirpasic/gods/v2/lists" + "github.com/emirpasic/gods/v2/utils" ) // Assert List implementation -var _ lists.List = (*List)(nil) +var _ lists.List[any] = (*List[any])(nil) // List holds the elements, where each element points to the next and previous element -type List struct { - first *element - last *element +type List[T comparable] struct { + first *element[T] + last *element[T] size int } -type element struct { - value interface{} - prev *element - next *element +type element[T comparable] struct { + value T + prev *element[T] + next *element[T] } // New instantiates a new list and adds the passed values, if any, to the list -func New(values ...interface{}) *List { - list := &List{} +func New[T comparable](values ...T) *List[T] { + list := &List[T]{} if len(values) > 0 { list.Add(values...) } @@ -43,9 +44,9 @@ func New(values ...interface{}) *List { } // Add appends a value (one or more) at the end of the list (same as Append()) -func (list *List) Add(values ...interface{}) { +func (list *List[T]) Add(values ...T) { for _, value := range values { - newElement := &element{value: value, prev: list.last} + newElement := &element[T]{value: value, prev: list.last} if list.size == 0 { list.first = newElement list.last = newElement @@ -58,15 +59,15 @@ func (list *List) Add(values ...interface{}) { } // Append appends a value (one or more) at the end of the list (same as Add()) -func (list *List) Append(values ...interface{}) { +func (list *List[T]) Append(values ...T) { list.Add(values...) } // Prepend prepends a values (or more) -func (list *List) Prepend(values ...interface{}) { +func (list *List[T]) Prepend(values ...T) { // in reverse to keep passed order i.e. ["c","d"] -> Prepend(["a","b"]) -> ["a","b","c",d"] for v := len(values) - 1; v >= 0; v-- { - newElement := &element{value: values[v], next: list.first} + newElement := &element[T]{value: values[v], next: list.first} if list.size == 0 { list.first = newElement list.last = newElement @@ -80,10 +81,11 @@ func (list *List) Prepend(values ...interface{}) { // Get returns the element at index. // Second return parameter is true if index is within bounds of the array and array is not empty, otherwise false. -func (list *List) Get(index int) (interface{}, bool) { +func (list *List[T]) Get(index int) (T, bool) { if !list.withinRange(index) { - return nil, false + var t T + return t, false } // determine traveral direction, last to first or first to last @@ -100,7 +102,7 @@ func (list *List) Get(index int) (interface{}, bool) { } // Remove removes the element at the given index from the list. -func (list *List) Remove(index int) { +func (list *List[T]) Remove(index int) { if !list.withinRange(index) { return @@ -111,7 +113,7 @@ func (list *List) Remove(index int) { return } - var element *element + var element *element[T] // determine traversal direction, last to first or first to last if list.size-index < index { element = list.last @@ -145,7 +147,7 @@ func (list *List) Remove(index int) { // All values have to be present in the set for the method to return true. // Performance time complexity of n^2. // Returns true if no arguments are passed at all, i.e. set is always super-set of empty set. -func (list *List) Contains(values ...interface{}) bool { +func (list *List[T]) Contains(values ...T) bool { if len(values) == 0 { return true @@ -169,8 +171,8 @@ func (list *List) Contains(values ...interface{}) bool { } // Values returns all elements in the list. -func (list *List) Values() []interface{} { - values := make([]interface{}, list.size, list.size) +func (list *List[T]) Values() []T { + values := make([]T, list.size, list.size) for e, element := 0, list.first; element != nil; e, element = e+1, element.next { values[e] = element.value } @@ -178,7 +180,7 @@ func (list *List) Values() []interface{} { } // IndexOf returns index of provided element -func (list *List) IndexOf(value interface{}) int { +func (list *List[T]) IndexOf(value T) int { if list.size == 0 { return -1 } @@ -191,31 +193,31 @@ func (list *List) IndexOf(value interface{}) int { } // Empty returns true if list does not contain any elements. -func (list *List) Empty() bool { +func (list *List[T]) Empty() bool { return list.size == 0 } // Size returns number of elements within the list. -func (list *List) Size() int { +func (list *List[T]) Size() int { return list.size } // Clear removes all elements from the list. -func (list *List) Clear() { +func (list *List[T]) Clear() { list.size = 0 list.first = nil list.last = nil } // Sort sorts values (in-place) using. -func (list *List) Sort(comparator utils.Comparator) { +func (list *List[T]) Sort(comparator utils.Comparator[T]) { if list.size < 2 { return } values := list.Values() - utils.Sort(values, comparator) + slices.SortFunc(values, comparator) list.Clear() @@ -224,9 +226,9 @@ func (list *List) Sort(comparator utils.Comparator) { } // Swap swaps values of two elements at the given indices. -func (list *List) Swap(i, j int) { +func (list *List[T]) Swap(i, j int) { if list.withinRange(i) && list.withinRange(j) && i != j { - var element1, element2 *element + var element1, element2 *element[T] for e, currentElement := 0, list.first; element1 == nil || element2 == nil; e, currentElement = e+1, currentElement.next { switch e { case i: @@ -242,7 +244,7 @@ func (list *List) Swap(i, j int) { // Insert inserts values at specified index position shifting the value at that position (if any) and any subsequent elements to the right. // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Insert(index int, values ...interface{}) { +func (list *List[T]) Insert(index int, values ...T) { if !list.withinRange(index) { // Append @@ -252,8 +254,8 @@ func (list *List) Insert(index int, values ...interface{}) { return } - var beforeElement *element - var foundElement *element + var beforeElement *element[T] + var foundElement *element[T] // determine traversal direction, last to first or first to last if list.size-index < index { foundElement = list.last @@ -271,7 +273,7 @@ func (list *List) Insert(index int, values ...interface{}) { if foundElement == list.first { oldNextElement := list.first for i, value := range values { - newElement := &element{value: value} + newElement := &element[T]{value: value} if i == 0 { list.first = newElement } else { @@ -285,7 +287,7 @@ func (list *List) Insert(index int, values ...interface{}) { } else { oldNextElement := beforeElement.next for _, value := range values { - newElement := &element{value: value} + newElement := &element[T]{value: value} newElement.prev = beforeElement beforeElement.next = newElement beforeElement = newElement @@ -300,7 +302,7 @@ func (list *List) Insert(index int, values ...interface{}) { // Set value at specified index position // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Set(index int, value interface{}) { +func (list *List[T]) Set(index int, value T) { if !list.withinRange(index) { // Append @@ -310,7 +312,7 @@ func (list *List) Set(index int, value interface{}) { return } - var foundElement *element + var foundElement *element[T] // determine traversal direction, last to first or first to last if list.size-index < index { foundElement = list.last @@ -329,7 +331,7 @@ func (list *List) Set(index int, value interface{}) { } // String returns a string representation of container -func (list *List) String() string { +func (list *List[T]) String() string { str := "DoublyLinkedList\n" values := []string{} for element := list.first; element != nil; element = element.next { @@ -340,6 +342,6 @@ func (list *List) String() string { } // Check that the index is within bounds of the list -func (list *List) withinRange(index int) bool { +func (list *List[T]) withinRange(index int) bool { return index >= 0 && index < list.size } diff --git a/lists/doublylinkedlist/doublylinkedlist_test.go b/lists/doublylinkedlist/doublylinkedlist_test.go index 312d5ff2..4a4afdc7 100644 --- a/lists/doublylinkedlist/doublylinkedlist_test.go +++ b/lists/doublylinkedlist/doublylinkedlist_test.go @@ -5,22 +5,21 @@ package doublylinkedlist import ( + "cmp" "encoding/json" - "fmt" + "slices" "strings" "testing" - - "github.com/emirpasic/gods/utils" ) func TestListNew(t *testing.T) { - list1 := New() + list1 := New[int]() if actualValue := list1.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } - list2 := New(1, "b") + list2 := New[int](1, 2) if actualValue := list2.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) @@ -30,17 +29,17 @@ func TestListNew(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, 1) } - if actualValue, ok := list2.Get(1); actualValue != "b" || !ok { - t.Errorf("Got %v expected %v", actualValue, "b") + if actualValue, ok := list2.Get(1); actualValue != 2 || !ok { + t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, ok := list2.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list2.Get(2); actualValue != 0 || ok { + t.Errorf("Got %v expected %v", actualValue, 0) } } func TestListAdd(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Empty(); actualValue != false { @@ -55,7 +54,7 @@ func TestListAdd(t *testing.T) { } func TestListAppendAndPrepend(t *testing.T) { - list := New() + list := New[string]() list.Add("b") list.Prepend("a") list.Append("c") @@ -77,12 +76,12 @@ func TestListAppendAndPrepend(t *testing.T) { } func TestListRemove(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Remove(2) - if actualValue, ok := list.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(2); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(1) list.Remove(0) @@ -96,7 +95,7 @@ func TestListRemove(t *testing.T) { } func TestListGet(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue, ok := list.Get(0); actualValue != "a" || !ok { @@ -108,8 +107,8 @@ func TestListGet(t *testing.T) { if actualValue, ok := list.Get(2); actualValue != "c" || !ok { t.Errorf("Got %v expected %v", actualValue, "c") } - if actualValue, ok := list.Get(3); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(3); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(0) if actualValue, ok := list.Get(0); actualValue != "b" || !ok { @@ -118,31 +117,31 @@ func TestListGet(t *testing.T) { } func TestListSwap(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Swap(0, 1) if actualValue, ok := list.Get(0); actualValue != "b" || !ok { - t.Errorf("Got %v expected %v", actualValue, "c") + t.Errorf("Got %v expected %v", actualValue, "b") } } func TestListSort(t *testing.T) { - list := New() - list.Sort(utils.StringComparator) + list := New[string]() + list.Sort(cmp.Compare[string]) list.Add("e", "f", "g", "a", "b", "c", "d") - list.Sort(utils.StringComparator) + list.Sort(cmp.Compare[string]) for i := 1; i < list.Size(); i++ { a, _ := list.Get(i - 1) b, _ := list.Get(i) - if a.(string) > b.(string) { + if a > b { t.Errorf("Not sorted! %s > %s", a, b) } } } func TestListClear(t *testing.T) { - list := New() + list := New[string]() list.Add("e", "f", "g", "a", "b", "c", "d") list.Clear() if actualValue := list.Empty(); actualValue != true { @@ -154,12 +153,15 @@ func TestListClear(t *testing.T) { } func TestListContains(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Contains("a"); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } + if actualValue := list.Contains(""); actualValue != false { + t.Errorf("Got %v expected %v", actualValue, false) + } if actualValue := list.Contains("a", "b", "c"); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -176,43 +178,16 @@ func TestListContains(t *testing.T) { } func TestListValues(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } -func TestListIndexOf(t *testing.T) { - list := New() - - expectedIndex := -1 - if index := list.IndexOf("a"); index != expectedIndex { - t.Errorf("Got %v expected %v", index, expectedIndex) - } - - list.Add("a") - list.Add("b", "c") - - expectedIndex = 0 - if index := list.IndexOf("a"); index != expectedIndex { - t.Errorf("Got %v expected %v", index, expectedIndex) - } - - expectedIndex = 1 - if index := list.IndexOf("b"); index != expectedIndex { - t.Errorf("Got %v expected %v", index, expectedIndex) - } - - expectedIndex = 2 - if index := list.IndexOf("c"); index != expectedIndex { - t.Errorf("Got %v expected %v", index, expectedIndex) - } -} - func TestListInsert(t *testing.T) { - list := New() + list := New[string]() list.Insert(0, "b", "c", "d") list.Insert(0, "a") list.Insert(10, "x") // ignore @@ -223,20 +198,20 @@ func TestListInsert(t *testing.T) { if actualValue := list.Size(); actualValue != 5 { t.Errorf("Got %v expected %v", actualValue, 5) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s%s", list.Values()...), "abcdg"; actualValue != expectedValue { + if actualValue, expectedValue := strings.Join(list.Values(), ""), "abcdg"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } list.Insert(4, "e", "f") // last to first traversal if actualValue := list.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s%s%s%s", list.Values()...), "abcdefg"; actualValue != expectedValue { + if actualValue, expectedValue := strings.Join(list.Values(), ""), "abcdefg"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListSet(t *testing.T) { - list := New() + list := New[string]() list.Set(0, "a") list.Set(1, "b") if actualValue := list.Size(); actualValue != 2 { @@ -251,20 +226,15 @@ func TestListSet(t *testing.T) { if actualValue := list.Size(); actualValue != 3 { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abbc"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - list.Set(2, "cc") // last to first traversal - list.Set(0, "aa") // first to last traversal - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "aabbcc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "bb", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListEach(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - list.Each(func(index int, value interface{}) { + list.Each(func(index int, value string) { switch index { case 0: if actualValue, expectedValue := value, "a"; actualValue != expectedValue { @@ -285,10 +255,10 @@ func TestListEach(t *testing.T) { } func TestListMap(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - mappedList := list.Map(func(index int, value interface{}) interface{} { - return "mapped: " + value.(string) + mappedList := list.Map(func(index int, value string) string { + return "mapped: " + value }) if actualValue, _ := mappedList.Get(0); actualValue != "mapped: a" { t.Errorf("Got %v expected %v", actualValue, "mapped: a") @@ -305,10 +275,10 @@ func TestListMap(t *testing.T) { } func TestListSelect(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - selectedList := list.Select(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + selectedList := list.Select(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if actualValue, _ := selectedList.Get(0); actualValue != "a" { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -322,60 +292,60 @@ func TestListSelect(t *testing.T) { } func TestListAny(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - any := list.Any(func(index int, value interface{}) bool { - return value.(string) == "c" + any := list.Any(func(index int, value string) bool { + return value == "c" }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = list.Any(func(index int, value interface{}) bool { - return value.(string) == "x" + any = list.Any(func(index int, value string) bool { + return value == "x" }) if any != false { t.Errorf("Got %v expected %v", any, false) } } func TestListAll(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - all := list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "c" + all := list.All(func(index int, value string) bool { + return value >= "a" && value <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + all = list.All(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) } } func TestListFind(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - foundIndex, foundValue := list.Find(func(index int, value interface{}) bool { - return value.(string) == "c" + foundIndex, foundValue := list.Find(func(index int, value string) bool { + return value == "c" }) if foundValue != "c" || foundIndex != 2 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 2) } - foundIndex, foundValue = list.Find(func(index int, value interface{}) bool { - return value.(string) == "x" + foundIndex, foundValue = list.Find(func(index int, value string) bool { + return value == "x" }) - if foundValue != nil || foundIndex != -1 { + if foundValue != "" || foundIndex != -1 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) } } func TestListChaining(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - chainedList := list.Select(func(index int, value interface{}) bool { - return value.(string) > "a" - }).Map(func(index int, value interface{}) interface{} { - return value.(string) + value.(string) + chainedList := list.Select(func(index int, value string) bool { + return value > "a" + }).Map(func(index int, value string) string { + return value + value }) if chainedList.Size() != 2 { t.Errorf("Got %v expected %v", chainedList.Size(), 2) @@ -389,7 +359,7 @@ func TestListChaining(t *testing.T) { } func TestListIteratorNextOnEmpty(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty list") @@ -397,7 +367,7 @@ func TestListIteratorNextOnEmpty(t *testing.T) { } func TestListIteratorNext(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") it := list.Iterator() count := 0 @@ -428,7 +398,7 @@ func TestListIteratorNext(t *testing.T) { } func TestListIteratorPrevOnEmpty(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty list") @@ -436,7 +406,7 @@ func TestListIteratorPrevOnEmpty(t *testing.T) { } func TestListIteratorPrev(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") it := list.Iterator() for it.Next() { @@ -469,7 +439,7 @@ func TestListIteratorPrev(t *testing.T) { } func TestListIteratorBegin(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() it.Begin() list.Add("a", "b", "c") @@ -483,7 +453,7 @@ func TestListIteratorBegin(t *testing.T) { } func TestListIteratorEnd(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if index := it.Index(); index != -1 { @@ -508,7 +478,7 @@ func TestListIteratorEnd(t *testing.T) { } func TestListIteratorFirst(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -523,7 +493,7 @@ func TestListIteratorFirst(t *testing.T) { } func TestListIteratorLast(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -539,13 +509,13 @@ func TestListIteratorLast(t *testing.T) { func TestListIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - list := New() + list := New[string]() it := list.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") @@ -554,7 +524,7 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (not found) { - list := New() + list := New[string]() list.Add("xx", "yy") it := list.Iterator() for it.NextTo(seek) { @@ -564,20 +534,20 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (found) { - list := New() + list := New[string]() list.Add("aa", "bb", "cc") it := list.Iterator() it.Begin() if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -588,13 +558,13 @@ func TestListIteratorNextTo(t *testing.T) { func TestListIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - list := New() + list := New[string]() it := list.Iterator() it.End() for it.PrevTo(seek) { @@ -604,7 +574,7 @@ func TestListIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - list := New() + list := New[string]() list.Add("xx", "yy") it := list.Iterator() it.End() @@ -615,20 +585,20 @@ func TestListIteratorPrevTo(t *testing.T) { // PrevTo (found) { - list := New() + list := New[string]() list.Add("aa", "bb", "cc") it := list.Iterator() it.End() if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -638,12 +608,12 @@ func TestListIteratorPrevTo(t *testing.T) { } func TestListSerialization(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue, expectedValue := list.Size(), 3; actualValue != expectedValue { @@ -662,26 +632,27 @@ func TestListSerialization(t *testing.T) { err = list.FromJSON(bytes) assert() - bytes, err = json.Marshal([]interface{}{"a", "b", "c", list}) + bytes, err = json.Marshal([]any{"a", "b", "c", list}) if err != nil { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &list) + err = json.Unmarshal([]byte(`["a","b","c"]`), &list) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestListString(t *testing.T) { - c := New() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "DoublyLinkedList") { t.Errorf("String should start with container name") } } -func benchmarkGet(b *testing.B, list *List, size int) { +func benchmarkGet(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Get(n) @@ -689,7 +660,7 @@ func benchmarkGet(b *testing.B, list *List, size int) { } } -func benchmarkAdd(b *testing.B, list *List, size int) { +func benchmarkAdd(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Add(n) @@ -697,7 +668,7 @@ func benchmarkAdd(b *testing.B, list *List, size int) { } } -func benchmarkRemove(b *testing.B, list *List, size int) { +func benchmarkRemove(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Remove(n) @@ -708,7 +679,7 @@ func benchmarkRemove(b *testing.B, list *List, size int) { func BenchmarkDoublyLinkedListGet100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -719,7 +690,7 @@ func BenchmarkDoublyLinkedListGet100(b *testing.B) { func BenchmarkDoublyLinkedListGet1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -730,7 +701,7 @@ func BenchmarkDoublyLinkedListGet1000(b *testing.B) { func BenchmarkDoublyLinkedListGet10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -741,7 +712,7 @@ func BenchmarkDoublyLinkedListGet10000(b *testing.B) { func BenchmarkDoublyLinkedListGet100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -752,7 +723,7 @@ func BenchmarkDoublyLinkedListGet100000(b *testing.B) { func BenchmarkDoublyLinkedListAdd100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() b.StartTimer() benchmarkAdd(b, list, size) } @@ -760,7 +731,7 @@ func BenchmarkDoublyLinkedListAdd100(b *testing.B) { func BenchmarkDoublyLinkedListAdd1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -771,7 +742,7 @@ func BenchmarkDoublyLinkedListAdd1000(b *testing.B) { func BenchmarkDoublyLinkedListAdd10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -782,7 +753,7 @@ func BenchmarkDoublyLinkedListAdd10000(b *testing.B) { func BenchmarkDoublyLinkedListAdd100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -793,7 +764,7 @@ func BenchmarkDoublyLinkedListAdd100000(b *testing.B) { func BenchmarkDoublyLinkedListRemove100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -804,7 +775,7 @@ func BenchmarkDoublyLinkedListRemove100(b *testing.B) { func BenchmarkDoublyLinkedListRemove1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -815,7 +786,7 @@ func BenchmarkDoublyLinkedListRemove1000(b *testing.B) { func BenchmarkDoublyLinkedListRemove10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -826,7 +797,7 @@ func BenchmarkDoublyLinkedListRemove10000(b *testing.B) { func BenchmarkDoublyLinkedListRemove100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } diff --git a/lists/doublylinkedlist/enumerable.go b/lists/doublylinkedlist/enumerable.go index 4b14a47f..4e6e7784 100644 --- a/lists/doublylinkedlist/enumerable.go +++ b/lists/doublylinkedlist/enumerable.go @@ -4,13 +4,13 @@ package doublylinkedlist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithIndex = (*List)(nil) +var _ containers.EnumerableWithIndex[int] = (*List[int])(nil) // Each calls the given function once for each element, passing that element's index and value. -func (list *List) Each(f func(index int, value interface{})) { +func (list *List[T]) Each(f func(index int, value T)) { iterator := list.Iterator() for iterator.Next() { f(iterator.Index(), iterator.Value()) @@ -19,8 +19,8 @@ func (list *List) Each(f func(index int, value interface{})) { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. -func (list *List) Map(f func(index int, value interface{}) interface{}) *List { - newList := &List{} +func (list *List[T]) Map(f func(index int, value T) T) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { newList.Add(f(iterator.Index(), iterator.Value())) @@ -29,8 +29,8 @@ func (list *List) Map(f func(index int, value interface{}) interface{}) *List { } // Select returns a new container containing all elements for which the given function returns a true value. -func (list *List) Select(f func(index int, value interface{}) bool) *List { - newList := &List{} +func (list *List[T]) Select(f func(index int, value T) bool) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -42,7 +42,7 @@ func (list *List) Select(f func(index int, value interface{}) bool) *List { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (list *List) Any(f func(index int, value interface{}) bool) bool { +func (list *List[T]) Any(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -54,7 +54,7 @@ func (list *List) Any(f func(index int, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (list *List) All(f func(index int, value interface{}) bool) bool { +func (list *List[T]) All(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if !f(iterator.Index(), iterator.Value()) { @@ -67,12 +67,13 @@ func (list *List) All(f func(index int, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. -func (list *List) Find(f func(index int, value interface{}) bool) (index int, value interface{}) { +func (list *List[T]) Find(f func(index int, value T) bool) (index int, value T) { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { return iterator.Index(), iterator.Value() } } - return -1, nil + var t T + return -1, t } diff --git a/lists/doublylinkedlist/iterator.go b/lists/doublylinkedlist/iterator.go index 27b34c2e..1b637073 100644 --- a/lists/doublylinkedlist/iterator.go +++ b/lists/doublylinkedlist/iterator.go @@ -4,28 +4,28 @@ package doublylinkedlist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator holding the iterator's state -type Iterator struct { - list *List +type Iterator[T comparable] struct { + list *List[T] index int - element *element + element *element[T] } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (list *List) Iterator() Iterator { - return Iterator{list: list, index: -1, element: nil} +func (list *List[T]) Iterator() Iterator[T] { + return Iterator[T]{list: list, index: -1, element: nil} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.list.size { iterator.index++ } @@ -44,7 +44,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -62,26 +62,26 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.element.value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 iterator.element = nil } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.list.size iterator.element = iterator.list.last } @@ -89,7 +89,7 @@ func (iterator *Iterator) End() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -97,7 +97,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -106,7 +106,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -120,7 +120,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/lists/doublylinkedlist/serialization.go b/lists/doublylinkedlist/serialization.go index f210f9a2..38d95bf2 100644 --- a/lists/doublylinkedlist/serialization.go +++ b/lists/doublylinkedlist/serialization.go @@ -6,21 +6,22 @@ package doublylinkedlist import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*List)(nil) -var _ containers.JSONDeserializer = (*List)(nil) +var _ containers.JSONSerializer = (*List[int])(nil) +var _ containers.JSONDeserializer = (*List[int])(nil) // ToJSON outputs the JSON representation of list's elements. -func (list *List) ToJSON() ([]byte, error) { +func (list *List[T]) ToJSON() ([]byte, error) { return json.Marshal(list.Values()) } // FromJSON populates list's elements from the input JSON representation. -func (list *List) FromJSON(data []byte) error { - elements := []interface{}{} +func (list *List[T]) FromJSON(data []byte) error { + var elements []T err := json.Unmarshal(data, &elements) if err == nil { list.Clear() @@ -30,11 +31,11 @@ func (list *List) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (list *List) UnmarshalJSON(bytes []byte) error { +func (list *List[T]) UnmarshalJSON(bytes []byte) error { return list.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (list *List) MarshalJSON() ([]byte, error) { +func (list *List[T]) MarshalJSON() ([]byte, error) { return list.ToJSON() } diff --git a/lists/lists.go b/lists/lists.go index 55bd619e..cc310366 100644 --- a/lists/lists.go +++ b/lists/lists.go @@ -10,22 +10,22 @@ package lists import ( - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + "github.com/emirpasic/gods/v2/containers" + "github.com/emirpasic/gods/v2/utils" ) // List interface that all lists implement -type List interface { - Get(index int) (interface{}, bool) +type List[T comparable] interface { + Get(index int) (T, bool) Remove(index int) - Add(values ...interface{}) - Contains(values ...interface{}) bool - Sort(comparator utils.Comparator) + Add(values ...T) + Contains(values ...T) bool + Sort(comparator utils.Comparator[T]) Swap(index1, index2 int) - Insert(index int, values ...interface{}) - Set(index int, value interface{}) + Insert(index int, values ...T) + Set(index int, value T) - containers.Container + containers.Container[T] // Empty() bool // Size() int // Clear() diff --git a/lists/singlylinkedlist/enumerable.go b/lists/singlylinkedlist/enumerable.go index 6fdbcb8b..febb90d4 100644 --- a/lists/singlylinkedlist/enumerable.go +++ b/lists/singlylinkedlist/enumerable.go @@ -4,13 +4,13 @@ package singlylinkedlist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithIndex = (*List)(nil) +var _ containers.EnumerableWithIndex[int] = (*List[int])(nil) // Each calls the given function once for each element, passing that element's index and value. -func (list *List) Each(f func(index int, value interface{})) { +func (list *List[T]) Each(f func(index int, value T)) { iterator := list.Iterator() for iterator.Next() { f(iterator.Index(), iterator.Value()) @@ -19,8 +19,8 @@ func (list *List) Each(f func(index int, value interface{})) { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. -func (list *List) Map(f func(index int, value interface{}) interface{}) *List { - newList := &List{} +func (list *List[T]) Map(f func(index int, value T) T) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { newList.Add(f(iterator.Index(), iterator.Value())) @@ -29,8 +29,8 @@ func (list *List) Map(f func(index int, value interface{}) interface{}) *List { } // Select returns a new container containing all elements for which the given function returns a true value. -func (list *List) Select(f func(index int, value interface{}) bool) *List { - newList := &List{} +func (list *List[T]) Select(f func(index int, value T) bool) *List[T] { + newList := &List[T]{} iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -42,7 +42,7 @@ func (list *List) Select(f func(index int, value interface{}) bool) *List { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (list *List) Any(f func(index int, value interface{}) bool) bool { +func (list *List[T]) Any(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -54,7 +54,7 @@ func (list *List) Any(f func(index int, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (list *List) All(f func(index int, value interface{}) bool) bool { +func (list *List[T]) All(f func(index int, value T) bool) bool { iterator := list.Iterator() for iterator.Next() { if !f(iterator.Index(), iterator.Value()) { @@ -67,12 +67,12 @@ func (list *List) All(f func(index int, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. -func (list *List) Find(f func(index int, value interface{}) bool) (index int, value interface{}) { +func (list *List[T]) Find(f func(index int, value T) bool) (index int, value T) { iterator := list.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { return iterator.Index(), iterator.Value() } } - return -1, nil + return -1, value } diff --git a/lists/singlylinkedlist/iterator.go b/lists/singlylinkedlist/iterator.go index 4e7f1773..dc3304a8 100644 --- a/lists/singlylinkedlist/iterator.go +++ b/lists/singlylinkedlist/iterator.go @@ -4,28 +4,28 @@ package singlylinkedlist -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.IteratorWithIndex = (*Iterator)(nil) +var _ containers.IteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator holding the iterator's state -type Iterator struct { - list *List +type Iterator[T comparable] struct { + list *List[T] index int - element *element + element *element[T] } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (list *List) Iterator() Iterator { - return Iterator{list: list, index: -1, element: nil} +func (list *List[T]) Iterator() *Iterator[T] { + return &Iterator[T]{list: list, index: -1, element: nil} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.list.size { iterator.index++ } @@ -43,19 +43,19 @@ func (iterator *Iterator) Next() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.element.value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 iterator.element = nil } @@ -63,7 +63,7 @@ func (iterator *Iterator) Begin() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -72,7 +72,7 @@ func (iterator *Iterator) First() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/lists/singlylinkedlist/serialization.go b/lists/singlylinkedlist/serialization.go index 588a316c..3cc49652 100644 --- a/lists/singlylinkedlist/serialization.go +++ b/lists/singlylinkedlist/serialization.go @@ -6,21 +6,22 @@ package singlylinkedlist import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*List)(nil) -var _ containers.JSONDeserializer = (*List)(nil) +var _ containers.JSONSerializer = (*List[int])(nil) +var _ containers.JSONDeserializer = (*List[int])(nil) // ToJSON outputs the JSON representation of list's elements. -func (list *List) ToJSON() ([]byte, error) { +func (list *List[T]) ToJSON() ([]byte, error) { return json.Marshal(list.Values()) } // FromJSON populates list's elements from the input JSON representation. -func (list *List) FromJSON(data []byte) error { - elements := []interface{}{} +func (list *List[T]) FromJSON(data []byte) error { + var elements []T err := json.Unmarshal(data, &elements) if err == nil { list.Clear() @@ -30,11 +31,11 @@ func (list *List) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (list *List) UnmarshalJSON(bytes []byte) error { +func (list *List[T]) UnmarshalJSON(bytes []byte) error { return list.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (list *List) MarshalJSON() ([]byte, error) { +func (list *List[T]) MarshalJSON() ([]byte, error) { return list.ToJSON() } diff --git a/lists/singlylinkedlist/singlylinkedlist.go b/lists/singlylinkedlist/singlylinkedlist.go index c3e2c675..a4126021 100644 --- a/lists/singlylinkedlist/singlylinkedlist.go +++ b/lists/singlylinkedlist/singlylinkedlist.go @@ -11,30 +11,31 @@ package singlylinkedlist import ( "fmt" + "slices" "strings" - "github.com/emirpasic/gods/lists" - "github.com/emirpasic/gods/utils" + "github.com/emirpasic/gods/v2/lists" + "github.com/emirpasic/gods/v2/utils" ) // Assert List implementation -var _ lists.List = (*List)(nil) +var _ lists.List[int] = (*List[int])(nil) // List holds the elements, where each element points to the next element -type List struct { - first *element - last *element +type List[T comparable] struct { + first *element[T] + last *element[T] size int } -type element struct { - value interface{} - next *element +type element[T comparable] struct { + value T + next *element[T] } // New instantiates a new list and adds the passed values, if any, to the list -func New(values ...interface{}) *List { - list := &List{} +func New[T comparable](values ...T) *List[T] { + list := &List[T]{} if len(values) > 0 { list.Add(values...) } @@ -42,9 +43,9 @@ func New(values ...interface{}) *List { } // Add appends a value (one or more) at the end of the list (same as Append()) -func (list *List) Add(values ...interface{}) { +func (list *List[T]) Add(values ...T) { for _, value := range values { - newElement := &element{value: value} + newElement := &element[T]{value: value} if list.size == 0 { list.first = newElement list.last = newElement @@ -57,15 +58,15 @@ func (list *List) Add(values ...interface{}) { } // Append appends a value (one or more) at the end of the list (same as Add()) -func (list *List) Append(values ...interface{}) { +func (list *List[T]) Append(values ...T) { list.Add(values...) } // Prepend prepends a values (or more) -func (list *List) Prepend(values ...interface{}) { +func (list *List[T]) Prepend(values ...T) { // in reverse to keep passed order i.e. ["c","d"] -> Prepend(["a","b"]) -> ["a","b","c",d"] for v := len(values) - 1; v >= 0; v-- { - newElement := &element{value: values[v], next: list.first} + newElement := &element[T]{value: values[v], next: list.first} list.first = newElement if list.size == 0 { list.last = newElement @@ -76,10 +77,11 @@ func (list *List) Prepend(values ...interface{}) { // Get returns the element at index. // Second return parameter is true if index is within bounds of the array and array is not empty, otherwise false. -func (list *List) Get(index int) (interface{}, bool) { +func (list *List[T]) Get(index int) (T, bool) { if !list.withinRange(index) { - return nil, false + var t T + return t, false } element := list.first @@ -90,7 +92,7 @@ func (list *List) Get(index int) (interface{}, bool) { } // Remove removes the element at the given index from the list. -func (list *List) Remove(index int) { +func (list *List[T]) Remove(index int) { if !list.withinRange(index) { return @@ -101,7 +103,7 @@ func (list *List) Remove(index int) { return } - var beforeElement *element + var beforeElement *element[T] element := list.first for e := 0; e != index; e, element = e+1, element.next { beforeElement = element @@ -126,7 +128,7 @@ func (list *List) Remove(index int) { // All values have to be present in the set for the method to return true. // Performance time complexity of n^2. // Returns true if no arguments are passed at all, i.e. set is always super-set of empty set. -func (list *List) Contains(values ...interface{}) bool { +func (list *List[T]) Contains(values ...T) bool { if len(values) == 0 { return true @@ -150,16 +152,16 @@ func (list *List) Contains(values ...interface{}) bool { } // Values returns all elements in the list. -func (list *List) Values() []interface{} { - values := make([]interface{}, list.size, list.size) +func (list *List[T]) Values() []T { + values := make([]T, list.size, list.size) for e, element := 0, list.first; element != nil; e, element = e+1, element.next { values[e] = element.value } return values } -//IndexOf returns index of provided element -func (list *List) IndexOf(value interface{}) int { +// IndexOf returns index of provided element +func (list *List[T]) IndexOf(value T) int { if list.size == 0 { return -1 } @@ -172,31 +174,31 @@ func (list *List) IndexOf(value interface{}) int { } // Empty returns true if list does not contain any elements. -func (list *List) Empty() bool { +func (list *List[T]) Empty() bool { return list.size == 0 } // Size returns number of elements within the list. -func (list *List) Size() int { +func (list *List[T]) Size() int { return list.size } // Clear removes all elements from the list. -func (list *List) Clear() { +func (list *List[T]) Clear() { list.size = 0 list.first = nil list.last = nil } // Sort sort values (in-place) using. -func (list *List) Sort(comparator utils.Comparator) { +func (list *List[T]) Sort(comparator utils.Comparator[T]) { if list.size < 2 { return } values := list.Values() - utils.Sort(values, comparator) + slices.SortFunc(values, comparator) list.Clear() @@ -205,9 +207,9 @@ func (list *List) Sort(comparator utils.Comparator) { } // Swap swaps values of two elements at the given indices. -func (list *List) Swap(i, j int) { +func (list *List[T]) Swap(i, j int) { if list.withinRange(i) && list.withinRange(j) && i != j { - var element1, element2 *element + var element1, element2 *element[T] for e, currentElement := 0, list.first; element1 == nil || element2 == nil; e, currentElement = e+1, currentElement.next { switch e { case i: @@ -223,7 +225,7 @@ func (list *List) Swap(i, j int) { // Insert inserts values at specified index position shifting the value at that position (if any) and any subsequent elements to the right. // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Insert(index int, values ...interface{}) { +func (list *List[T]) Insert(index int, values ...T) { if !list.withinRange(index) { // Append @@ -235,7 +237,7 @@ func (list *List) Insert(index int, values ...interface{}) { list.size += len(values) - var beforeElement *element + var beforeElement *element[T] foundElement := list.first for e := 0; e != index; e, foundElement = e+1, foundElement.next { beforeElement = foundElement @@ -244,7 +246,7 @@ func (list *List) Insert(index int, values ...interface{}) { if foundElement == list.first { oldNextElement := list.first for i, value := range values { - newElement := &element{value: value} + newElement := &element[T]{value: value} if i == 0 { list.first = newElement } else { @@ -256,7 +258,7 @@ func (list *List) Insert(index int, values ...interface{}) { } else { oldNextElement := beforeElement.next for _, value := range values { - newElement := &element{value: value} + newElement := &element[T]{value: value} beforeElement.next = newElement beforeElement = newElement } @@ -267,7 +269,7 @@ func (list *List) Insert(index int, values ...interface{}) { // Set value at specified index // Does not do anything if position is negative or bigger than list's size // Note: position equal to list's size is valid, i.e. append. -func (list *List) Set(index int, value interface{}) { +func (list *List[T]) Set(index int, value T) { if !list.withinRange(index) { // Append @@ -285,7 +287,7 @@ func (list *List) Set(index int, value interface{}) { } // String returns a string representation of container -func (list *List) String() string { +func (list *List[T]) String() string { str := "SinglyLinkedList\n" values := []string{} for element := list.first; element != nil; element = element.next { @@ -296,6 +298,6 @@ func (list *List) String() string { } // Check that the index is within bounds of the list -func (list *List) withinRange(index int) bool { +func (list *List[T]) withinRange(index int) bool { return index >= 0 && index < list.size } diff --git a/lists/singlylinkedlist/singlylinkedlist_test.go b/lists/singlylinkedlist/singlylinkedlist_test.go index 4d58b7d5..6dd1d944 100644 --- a/lists/singlylinkedlist/singlylinkedlist_test.go +++ b/lists/singlylinkedlist/singlylinkedlist_test.go @@ -5,22 +5,21 @@ package singlylinkedlist import ( + "cmp" "encoding/json" - "fmt" + "slices" "strings" "testing" - - "github.com/emirpasic/gods/utils" ) func TestListNew(t *testing.T) { - list1 := New() + list1 := New[int]() if actualValue := list1.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } - list2 := New(1, "b") + list2 := New[int](1, 2) if actualValue := list2.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) @@ -30,17 +29,17 @@ func TestListNew(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, 1) } - if actualValue, ok := list2.Get(1); actualValue != "b" || !ok { - t.Errorf("Got %v expected %v", actualValue, "b") + if actualValue, ok := list2.Get(1); actualValue != 2 || !ok { + t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, ok := list2.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list2.Get(2); actualValue != 0 || ok { + t.Errorf("Got %v expected %v", actualValue, 0) } } func TestListAdd(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Empty(); actualValue != false { @@ -55,7 +54,7 @@ func TestListAdd(t *testing.T) { } func TestListAppendAndPrepend(t *testing.T) { - list := New() + list := New[string]() list.Add("b") list.Prepend("a") list.Append("c") @@ -77,12 +76,12 @@ func TestListAppendAndPrepend(t *testing.T) { } func TestListRemove(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Remove(2) - if actualValue, ok := list.Get(2); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(2); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(1) list.Remove(0) @@ -96,7 +95,7 @@ func TestListRemove(t *testing.T) { } func TestListGet(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue, ok := list.Get(0); actualValue != "a" || !ok { @@ -108,8 +107,8 @@ func TestListGet(t *testing.T) { if actualValue, ok := list.Get(2); actualValue != "c" || !ok { t.Errorf("Got %v expected %v", actualValue, "c") } - if actualValue, ok := list.Get(3); actualValue != nil || ok { - t.Errorf("Got %v expected %v", actualValue, nil) + if actualValue, ok := list.Get(3); actualValue != "" || ok { + t.Errorf("Got %v expected %v", actualValue, "") } list.Remove(0) if actualValue, ok := list.Get(0); actualValue != "b" || !ok { @@ -118,31 +117,31 @@ func TestListGet(t *testing.T) { } func TestListSwap(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") list.Swap(0, 1) if actualValue, ok := list.Get(0); actualValue != "b" || !ok { - t.Errorf("Got %v expected %v", actualValue, "c") + t.Errorf("Got %v expected %v", actualValue, "b") } } func TestListSort(t *testing.T) { - list := New() - list.Sort(utils.StringComparator) + list := New[string]() + list.Sort(cmp.Compare[string]) list.Add("e", "f", "g", "a", "b", "c", "d") - list.Sort(utils.StringComparator) + list.Sort(cmp.Compare[string]) for i := 1; i < list.Size(); i++ { a, _ := list.Get(i - 1) b, _ := list.Get(i) - if a.(string) > b.(string) { + if a > b { t.Errorf("Not sorted! %s > %s", a, b) } } } func TestListClear(t *testing.T) { - list := New() + list := New[string]() list.Add("e", "f", "g", "a", "b", "c", "d") list.Clear() if actualValue := list.Empty(); actualValue != true { @@ -154,12 +153,15 @@ func TestListClear(t *testing.T) { } func TestListContains(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") if actualValue := list.Contains("a"); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } + if actualValue := list.Contains(""); actualValue != false { + t.Errorf("Got %v expected %v", actualValue, false) + } if actualValue := list.Contains("a", "b", "c"); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -176,16 +178,16 @@ func TestListContains(t *testing.T) { } func TestListValues(t *testing.T) { - list := New() + list := New[string]() list.Add("a") list.Add("b", "c") - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListIndexOf(t *testing.T) { - list := New() + list := New[string]() expectedIndex := -1 if index := list.IndexOf("a"); index != expectedIndex { @@ -212,7 +214,7 @@ func TestListIndexOf(t *testing.T) { } func TestListInsert(t *testing.T) { - list := New() + list := New[string]() list.Insert(0, "b", "c") list.Insert(0, "a") list.Insert(10, "x") // ignore @@ -223,13 +225,13 @@ func TestListInsert(t *testing.T) { if actualValue := list.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", list.Values()...), "abcd"; actualValue != expectedValue { + if actualValue, expectedValue := strings.Join(list.Values(), ""), "abcd"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListSet(t *testing.T) { - list := New() + list := New[string]() list.Set(0, "a") list.Set(1, "b") if actualValue := list.Size(); actualValue != 2 { @@ -244,15 +246,15 @@ func TestListSet(t *testing.T) { if actualValue := list.Size(); actualValue != 3 { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abbc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "bb", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestListEach(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - list.Each(func(index int, value interface{}) { + list.Each(func(index int, value string) { switch index { case 0: if actualValue, expectedValue := value, "a"; actualValue != expectedValue { @@ -273,10 +275,10 @@ func TestListEach(t *testing.T) { } func TestListMap(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - mappedList := list.Map(func(index int, value interface{}) interface{} { - return "mapped: " + value.(string) + mappedList := list.Map(func(index int, value string) string { + return "mapped: " + value }) if actualValue, _ := mappedList.Get(0); actualValue != "mapped: a" { t.Errorf("Got %v expected %v", actualValue, "mapped: a") @@ -293,10 +295,10 @@ func TestListMap(t *testing.T) { } func TestListSelect(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - selectedList := list.Select(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + selectedList := list.Select(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if actualValue, _ := selectedList.Get(0); actualValue != "a" { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -310,60 +312,60 @@ func TestListSelect(t *testing.T) { } func TestListAny(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - any := list.Any(func(index int, value interface{}) bool { - return value.(string) == "c" + any := list.Any(func(index int, value string) bool { + return value == "c" }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = list.Any(func(index int, value interface{}) bool { - return value.(string) == "x" + any = list.Any(func(index int, value string) bool { + return value == "x" }) if any != false { t.Errorf("Got %v expected %v", any, false) } } func TestListAll(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - all := list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "c" + all := list.All(func(index int, value string) bool { + return value >= "a" && value <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = list.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + all = list.All(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) } } func TestListFind(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - foundIndex, foundValue := list.Find(func(index int, value interface{}) bool { - return value.(string) == "c" + foundIndex, foundValue := list.Find(func(index int, value string) bool { + return value == "c" }) if foundValue != "c" || foundIndex != 2 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 2) } - foundIndex, foundValue = list.Find(func(index int, value interface{}) bool { - return value.(string) == "x" + foundIndex, foundValue = list.Find(func(index int, value string) bool { + return value == "x" }) - if foundValue != nil || foundIndex != -1 { + if foundValue != "" || foundIndex != -1 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) } } func TestListChaining(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") - chainedList := list.Select(func(index int, value interface{}) bool { - return value.(string) > "a" - }).Map(func(index int, value interface{}) interface{} { - return value.(string) + value.(string) + chainedList := list.Select(func(index int, value string) bool { + return value > "a" + }).Map(func(index int, value string) string { + return value + value }) if chainedList.Size() != 2 { t.Errorf("Got %v expected %v", chainedList.Size(), 2) @@ -377,7 +379,7 @@ func TestListChaining(t *testing.T) { } func TestListIteratorNextOnEmpty(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty list") @@ -385,7 +387,7 @@ func TestListIteratorNextOnEmpty(t *testing.T) { } func TestListIteratorNext(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") it := list.Iterator() count := 0 @@ -416,7 +418,7 @@ func TestListIteratorNext(t *testing.T) { } func TestListIteratorBegin(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() it.Begin() list.Add("a", "b", "c") @@ -430,7 +432,7 @@ func TestListIteratorBegin(t *testing.T) { } func TestListIteratorFirst(t *testing.T) { - list := New() + list := New[string]() it := list.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -446,13 +448,13 @@ func TestListIteratorFirst(t *testing.T) { func TestListIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - list := New() + list := New[string]() it := list.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") @@ -461,7 +463,7 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (not found) { - list := New() + list := New[string]() list.Add("xx", "yy") it := list.Iterator() for it.NextTo(seek) { @@ -471,20 +473,20 @@ func TestListIteratorNextTo(t *testing.T) { // NextTo (found) { - list := New() + list := New[string]() list.Add("aa", "bb", "cc") it := list.Iterator() it.Begin() if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -494,12 +496,12 @@ func TestListIteratorNextTo(t *testing.T) { } func TestListSerialization(t *testing.T) { - list := New() + list := New[string]() list.Add("a", "b", "c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", list.Values()...), "abc"; actualValue != expectedValue { + if actualValue, expectedValue := list.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue, expectedValue := list.Size(), 3; actualValue != expectedValue { @@ -518,26 +520,27 @@ func TestListSerialization(t *testing.T) { err = list.FromJSON(bytes) assert() - bytes, err = json.Marshal([]interface{}{"a", "b", "c", list}) + bytes, err = json.Marshal([]any{"a", "b", "c", list}) if err != nil { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &list) + err = json.Unmarshal([]byte(`["a","b","c"]`), &list) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestListString(t *testing.T) { - c := New() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "SinglyLinkedList") { t.Errorf("String should start with container name") } } -func benchmarkGet(b *testing.B, list *List, size int) { +func benchmarkGet(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Get(n) @@ -545,7 +548,7 @@ func benchmarkGet(b *testing.B, list *List, size int) { } } -func benchmarkAdd(b *testing.B, list *List, size int) { +func benchmarkAdd(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Add(n) @@ -553,7 +556,7 @@ func benchmarkAdd(b *testing.B, list *List, size int) { } } -func benchmarkRemove(b *testing.B, list *List, size int) { +func benchmarkRemove(b *testing.B, list *List[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { list.Remove(n) @@ -564,7 +567,7 @@ func benchmarkRemove(b *testing.B, list *List, size int) { func BenchmarkSinglyLinkedListGet100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -575,7 +578,7 @@ func BenchmarkSinglyLinkedListGet100(b *testing.B) { func BenchmarkSinglyLinkedListGet1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -586,7 +589,7 @@ func BenchmarkSinglyLinkedListGet1000(b *testing.B) { func BenchmarkSinglyLinkedListGet10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -597,7 +600,7 @@ func BenchmarkSinglyLinkedListGet10000(b *testing.B) { func BenchmarkSinglyLinkedListGet100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -608,7 +611,7 @@ func BenchmarkSinglyLinkedListGet100000(b *testing.B) { func BenchmarkSinglyLinkedListAdd100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() b.StartTimer() benchmarkAdd(b, list, size) } @@ -616,7 +619,7 @@ func BenchmarkSinglyLinkedListAdd100(b *testing.B) { func BenchmarkSinglyLinkedListAdd1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -627,7 +630,7 @@ func BenchmarkSinglyLinkedListAdd1000(b *testing.B) { func BenchmarkSinglyLinkedListAdd10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -638,7 +641,7 @@ func BenchmarkSinglyLinkedListAdd10000(b *testing.B) { func BenchmarkSinglyLinkedListAdd100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -649,7 +652,7 @@ func BenchmarkSinglyLinkedListAdd100000(b *testing.B) { func BenchmarkSinglyLinkedListRemove100(b *testing.B) { b.StopTimer() size := 100 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -660,7 +663,7 @@ func BenchmarkSinglyLinkedListRemove100(b *testing.B) { func BenchmarkSinglyLinkedListRemove1000(b *testing.B) { b.StopTimer() size := 1000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -671,7 +674,7 @@ func BenchmarkSinglyLinkedListRemove1000(b *testing.B) { func BenchmarkSinglyLinkedListRemove10000(b *testing.B) { b.StopTimer() size := 10000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } @@ -682,7 +685,7 @@ func BenchmarkSinglyLinkedListRemove10000(b *testing.B) { func BenchmarkSinglyLinkedListRemove100000(b *testing.B) { b.StopTimer() size := 100000 - list := New() + list := New[int]() for n := 0; n < size; n++ { list.Add(n) } diff --git a/maps/hashbidimap/hashbidimap.go b/maps/hashbidimap/hashbidimap.go index 5a386ec7..cf947a82 100644 --- a/maps/hashbidimap/hashbidimap.go +++ b/maps/hashbidimap/hashbidimap.go @@ -17,26 +17,27 @@ package hashbidimap import ( "fmt" - "github.com/emirpasic/gods/maps" - "github.com/emirpasic/gods/maps/hashmap" + + "github.com/emirpasic/gods/v2/maps" + "github.com/emirpasic/gods/v2/maps/hashmap" ) // Assert Map implementation -var _ maps.BidiMap = (*Map)(nil) +var _ maps.BidiMap[string, int] = (*Map[string, int])(nil) // Map holds the elements in two hashmaps. -type Map struct { - forwardMap hashmap.Map - inverseMap hashmap.Map +type Map[K, V comparable] struct { + forwardMap hashmap.Map[K, V] + inverseMap hashmap.Map[V, K] } // New instantiates a bidirectional map. -func New() *Map { - return &Map{*hashmap.New(), *hashmap.New()} +func New[K, V comparable]() *Map[K, V] { + return &Map[K, V]{*hashmap.New[K, V](), *hashmap.New[V, K]()} } // Put inserts element into the map. -func (m *Map) Put(key interface{}, value interface{}) { +func (m *Map[K, V]) Put(key K, value V) { if valueByKey, ok := m.forwardMap.Get(key); ok { m.inverseMap.Remove(valueByKey) } @@ -49,18 +50,18 @@ func (m *Map) Put(key interface{}, value interface{}) { // Get searches the element in the map by key and returns its value or nil if key is not found in map. // Second return parameter is true if key was found, otherwise false. -func (m *Map) Get(key interface{}) (value interface{}, found bool) { +func (m *Map[K, V]) Get(key K) (value V, found bool) { return m.forwardMap.Get(key) } // GetKey searches the element in the map by value and returns its key or nil if value is not found in map. // Second return parameter is true if value was found, otherwise false. -func (m *Map) GetKey(value interface{}) (key interface{}, found bool) { +func (m *Map[K, V]) GetKey(value V) (key K, found bool) { return m.inverseMap.Get(value) } // Remove removes the element from the map by key. -func (m *Map) Remove(key interface{}) { +func (m *Map[K, V]) Remove(key K) { if value, found := m.forwardMap.Get(key); found { m.forwardMap.Remove(key) m.inverseMap.Remove(value) @@ -68,33 +69,33 @@ func (m *Map) Remove(key interface{}) { } // Empty returns true if map does not contain any elements -func (m *Map) Empty() bool { +func (m *Map[K, V]) Empty() bool { return m.Size() == 0 } // Size returns number of elements in the map. -func (m *Map) Size() int { +func (m *Map[K, V]) Size() int { return m.forwardMap.Size() } // Keys returns all keys (random order). -func (m *Map) Keys() []interface{} { +func (m *Map[K, V]) Keys() []K { return m.forwardMap.Keys() } // Values returns all values (random order). -func (m *Map) Values() []interface{} { +func (m *Map[K, V]) Values() []V { return m.inverseMap.Keys() } // Clear removes all elements from the map. -func (m *Map) Clear() { +func (m *Map[K, V]) Clear() { m.forwardMap.Clear() m.inverseMap.Clear() } // String returns a string representation of container -func (m *Map) String() string { +func (m *Map[K, V]) String() string { str := "HashBidiMap\n" str += fmt.Sprintf("%v", m.forwardMap) return str diff --git a/maps/hashbidimap/hashbidimap_test.go b/maps/hashbidimap/hashbidimap_test.go index dd911165..64fae28f 100644 --- a/maps/hashbidimap/hashbidimap_test.go +++ b/maps/hashbidimap/hashbidimap_test.go @@ -6,13 +6,14 @@ package hashbidimap import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestMapPut(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -25,12 +26,8 @@ func TestMapPut(t *testing.T) { if actualValue := m.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4, 5, 6, 7}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d", "e", "f", "g"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4, 5, 6, 7}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}) // key,expectedValue,expectedFound tests1 := [][]interface{}{ @@ -41,12 +38,12 @@ func TestMapPut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -54,7 +51,7 @@ func TestMapPut(t *testing.T) { } func TestMapRemove(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -70,13 +67,9 @@ func TestMapRemove(t *testing.T) { m.Remove(8) m.Remove(5) - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d"}) - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } if actualValue := m.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } @@ -86,14 +79,14 @@ func TestMapRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -106,12 +99,8 @@ func TestMapRemove(t *testing.T) { m.Remove(2) m.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", m.Keys()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s", m.Values()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), nil) + testutils.SameElements(t, m.Values(), nil) if actualValue := m.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) } @@ -121,7 +110,7 @@ func TestMapRemove(t *testing.T) { } func TestMapGetKey(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -140,12 +129,12 @@ func TestMapGetKey(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {nil, "x", false}, + {0, "x", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.GetKey(test[1]) + actualValue, actualFound := m.GetKey(test[1].(string)) if actualValue != test[0] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[0]) } @@ -153,19 +142,15 @@ func TestMapGetKey(t *testing.T) { } func TestMapSerialization(t *testing.T) { - m := New() + m := New[string, float64]() m.Put("a", 1.0) m.Put("b", 2.0) m.Put("c", 3.0) var err error assert := func() { - if actualValue, expectedValue := m.Keys(), []interface{}{"a", "b", "c"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{1.0, 2.0, 3.0}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []string{"a", "b", "c"}) + testutils.SameElements(t, m.Values(), []float64{1.0, 2.0, 3.0}) if actualValue, expectedValue := m.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -194,33 +179,14 @@ func TestMapSerialization(t *testing.T) { } func TestMapString(t *testing.T) { - c := New() + c := New[string, int]() c.Put("a", 1) if !strings.HasPrefix(c.String(), "HashBidiMap") { t.Errorf("String should start with container name") } } -func sameElements(a []interface{}, b []interface{}) bool { - if len(a) != len(b) { - return false - } - for _, av := range a { - found := false - for _, bv := range b { - if av == bv { - found = true - break - } - } - if !found { - return false - } - } - return true -} - -func benchmarkGet(b *testing.B, m *Map, size int) { +func benchmarkGet(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Get(n) @@ -228,7 +194,7 @@ func benchmarkGet(b *testing.B, m *Map, size int) { } } -func benchmarkPut(b *testing.B, m *Map, size int) { +func benchmarkPut(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Put(n, n) @@ -236,7 +202,7 @@ func benchmarkPut(b *testing.B, m *Map, size int) { } } -func benchmarkRemove(b *testing.B, m *Map, size int) { +func benchmarkRemove(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Remove(n) @@ -247,7 +213,7 @@ func benchmarkRemove(b *testing.B, m *Map, size int) { func BenchmarkHashBidiMapGet100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -258,7 +224,7 @@ func BenchmarkHashBidiMapGet100(b *testing.B) { func BenchmarkHashBidiMapGet1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -269,7 +235,7 @@ func BenchmarkHashBidiMapGet1000(b *testing.B) { func BenchmarkHashBidiMapGet10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -280,7 +246,7 @@ func BenchmarkHashBidiMapGet10000(b *testing.B) { func BenchmarkHashBidiMapGet100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -291,7 +257,7 @@ func BenchmarkHashBidiMapGet100000(b *testing.B) { func BenchmarkHashBidiMapPut100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() b.StartTimer() benchmarkPut(b, m, size) } @@ -299,7 +265,7 @@ func BenchmarkHashBidiMapPut100(b *testing.B) { func BenchmarkHashBidiMapPut1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -310,7 +276,7 @@ func BenchmarkHashBidiMapPut1000(b *testing.B) { func BenchmarkHashBidiMapPut10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -321,7 +287,7 @@ func BenchmarkHashBidiMapPut10000(b *testing.B) { func BenchmarkHashBidiMapPut100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -332,7 +298,7 @@ func BenchmarkHashBidiMapPut100000(b *testing.B) { func BenchmarkHashBidiMapRemove100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -343,7 +309,7 @@ func BenchmarkHashBidiMapRemove100(b *testing.B) { func BenchmarkHashBidiMapRemove1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -354,7 +320,7 @@ func BenchmarkHashBidiMapRemove1000(b *testing.B) { func BenchmarkHashBidiMapRemove10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -365,7 +331,7 @@ func BenchmarkHashBidiMapRemove10000(b *testing.B) { func BenchmarkHashBidiMapRemove100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } diff --git a/maps/hashbidimap/serialization.go b/maps/hashbidimap/serialization.go index dfae0430..9decc35a 100644 --- a/maps/hashbidimap/serialization.go +++ b/maps/hashbidimap/serialization.go @@ -6,37 +6,41 @@ package hashbidimap import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Map)(nil) -var _ containers.JSONDeserializer = (*Map)(nil) +var _ containers.JSONSerializer = (*Map[string, int])(nil) +var _ containers.JSONDeserializer = (*Map[string, int])(nil) // ToJSON outputs the JSON representation of the map. -func (m *Map) ToJSON() ([]byte, error) { +func (m *Map[K, V]) ToJSON() ([]byte, error) { return m.forwardMap.ToJSON() } // FromJSON populates the map from the input JSON representation. -func (m *Map) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (m *Map[K, V]) FromJSON(data []byte) error { + var elements map[K]V err := json.Unmarshal(data, &elements) - if err == nil { - m.Clear() - for key, value := range elements { - m.Put(key, value) - } + if err != nil { + return err + } + + m.Clear() + for k, v := range elements { + m.Put(k, v) } - return err + + return nil } // UnmarshalJSON @implements json.Unmarshaler -func (m *Map) UnmarshalJSON(bytes []byte) error { +func (m *Map[K, V]) UnmarshalJSON(bytes []byte) error { return m.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (m *Map) MarshalJSON() ([]byte, error) { +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { return m.ToJSON() } diff --git a/maps/hashmap/hashmap.go b/maps/hashmap/hashmap.go index e945c39f..e2bbbf52 100644 --- a/maps/hashmap/hashmap.go +++ b/maps/hashmap/hashmap.go @@ -13,52 +13,53 @@ package hashmap import ( "fmt" - "github.com/emirpasic/gods/maps" + + "github.com/emirpasic/gods/v2/maps" ) // Assert Map implementation -var _ maps.Map = (*Map)(nil) +var _ maps.Map[string, int] = (*Map[string, int])(nil) // Map holds the elements in go's native map -type Map struct { - m map[interface{}]interface{} +type Map[K comparable, V any] struct { + m map[K]V } // New instantiates a hash map. -func New() *Map { - return &Map{m: make(map[interface{}]interface{})} +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{m: make(map[K]V)} } // Put inserts element into the map. -func (m *Map) Put(key interface{}, value interface{}) { +func (m *Map[K, V]) Put(key K, value V) { m.m[key] = value } // Get searches the element in the map by key and returns its value or nil if key is not found in map. // Second return parameter is true if key was found, otherwise false. -func (m *Map) Get(key interface{}) (value interface{}, found bool) { +func (m *Map[K, V]) Get(key K) (value V, found bool) { value, found = m.m[key] return } // Remove removes the element from the map by key. -func (m *Map) Remove(key interface{}) { +func (m *Map[K, V]) Remove(key K) { delete(m.m, key) } // Empty returns true if map does not contain any elements -func (m *Map) Empty() bool { +func (m *Map[K, V]) Empty() bool { return m.Size() == 0 } // Size returns number of elements in the map. -func (m *Map) Size() int { +func (m *Map[K, V]) Size() int { return len(m.m) } // Keys returns all keys (random order). -func (m *Map) Keys() []interface{} { - keys := make([]interface{}, m.Size()) +func (m *Map[K, V]) Keys() []K { + keys := make([]K, m.Size()) count := 0 for key := range m.m { keys[count] = key @@ -68,8 +69,8 @@ func (m *Map) Keys() []interface{} { } // Values returns all values (random order). -func (m *Map) Values() []interface{} { - values := make([]interface{}, m.Size()) +func (m *Map[K, V]) Values() []V { + values := make([]V, m.Size()) count := 0 for _, value := range m.m { values[count] = value @@ -79,12 +80,12 @@ func (m *Map) Values() []interface{} { } // Clear removes all elements from the map. -func (m *Map) Clear() { - m.m = make(map[interface{}]interface{}) +func (m *Map[K, V]) Clear() { + clear(m.m) } // String returns a string representation of container -func (m *Map) String() string { +func (m *Map[K, V]) String() string { str := "HashMap\n" str += fmt.Sprintf("%v", m.m) return str diff --git a/maps/hashmap/hashmap_test.go b/maps/hashmap/hashmap_test.go index 91acca8d..5e650682 100644 --- a/maps/hashmap/hashmap_test.go +++ b/maps/hashmap/hashmap_test.go @@ -6,13 +6,14 @@ package hashmap import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestMapPut(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -25,12 +26,8 @@ func TestMapPut(t *testing.T) { if actualValue := m.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4, 5, 6, 7}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d", "e", "f", "g"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4, 5, 6, 7}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}) // key,expectedValue,expectedFound tests1 := [][]interface{}{ @@ -41,12 +38,12 @@ func TestMapPut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -54,7 +51,7 @@ func TestMapPut(t *testing.T) { } func TestMapRemove(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -70,13 +67,9 @@ func TestMapRemove(t *testing.T) { m.Remove(8) m.Remove(5) - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d"}) - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } if actualValue := m.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } @@ -86,14 +79,14 @@ func TestMapRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -106,12 +99,8 @@ func TestMapRemove(t *testing.T) { m.Remove(2) m.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", m.Keys()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s", m.Values()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), nil) + testutils.SameElements(t, m.Values(), nil) if actualValue := m.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) } @@ -121,19 +110,15 @@ func TestMapRemove(t *testing.T) { } func TestMapSerialization(t *testing.T) { - m := New() + m := New[string, float64]() m.Put("a", 1.0) m.Put("b", 2.0) m.Put("c", 3.0) var err error assert := func() { - if actualValue, expectedValue := m.Keys(), []interface{}{"a", "b", "c"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{1.0, 2.0, 3.0}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []string{"a", "b", "c"}) + testutils.SameElements(t, m.Values(), []float64{1.0, 2.0, 3.0}) if actualValue, expectedValue := m.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -162,33 +147,14 @@ func TestMapSerialization(t *testing.T) { } func TestMapString(t *testing.T) { - c := New() + c := New[string, int]() c.Put("a", 1) if !strings.HasPrefix(c.String(), "HashMap") { t.Errorf("String should start with container name") } } -func sameElements(a []interface{}, b []interface{}) bool { - if len(a) != len(b) { - return false - } - for _, av := range a { - found := false - for _, bv := range b { - if av == bv { - found = true - break - } - } - if !found { - return false - } - } - return true -} - -func benchmarkGet(b *testing.B, m *Map, size int) { +func benchmarkGet(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Get(n) @@ -196,15 +162,15 @@ func benchmarkGet(b *testing.B, m *Map, size int) { } } -func benchmarkPut(b *testing.B, m *Map, size int) { +func benchmarkPut(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } } } -func benchmarkRemove(b *testing.B, m *Map, size int) { +func benchmarkRemove(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Remove(n) @@ -215,9 +181,9 @@ func benchmarkRemove(b *testing.B, m *Map, size int) { func BenchmarkHashMapGet100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -226,9 +192,9 @@ func BenchmarkHashMapGet100(b *testing.B) { func BenchmarkHashMapGet1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -237,9 +203,9 @@ func BenchmarkHashMapGet1000(b *testing.B) { func BenchmarkHashMapGet10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -248,9 +214,9 @@ func BenchmarkHashMapGet10000(b *testing.B) { func BenchmarkHashMapGet100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -259,7 +225,7 @@ func BenchmarkHashMapGet100000(b *testing.B) { func BenchmarkHashMapPut100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() b.StartTimer() benchmarkPut(b, m, size) } @@ -267,9 +233,9 @@ func BenchmarkHashMapPut100(b *testing.B) { func BenchmarkHashMapPut1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -278,9 +244,9 @@ func BenchmarkHashMapPut1000(b *testing.B) { func BenchmarkHashMapPut10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -289,9 +255,9 @@ func BenchmarkHashMapPut10000(b *testing.B) { func BenchmarkHashMapPut100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -300,9 +266,9 @@ func BenchmarkHashMapPut100000(b *testing.B) { func BenchmarkHashMapRemove100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -311,9 +277,9 @@ func BenchmarkHashMapRemove100(b *testing.B) { func BenchmarkHashMapRemove1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -322,9 +288,9 @@ func BenchmarkHashMapRemove1000(b *testing.B) { func BenchmarkHashMapRemove10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -333,9 +299,9 @@ func BenchmarkHashMapRemove10000(b *testing.B) { func BenchmarkHashMapRemove100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) diff --git a/maps/hashmap/serialization.go b/maps/hashmap/serialization.go index a86fd864..f1624998 100644 --- a/maps/hashmap/serialization.go +++ b/maps/hashmap/serialization.go @@ -6,42 +6,30 @@ package hashmap import ( "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Map)(nil) -var _ containers.JSONDeserializer = (*Map)(nil) +var _ containers.JSONSerializer = (*Map[string, int])(nil) +var _ containers.JSONDeserializer = (*Map[string, int])(nil) // ToJSON outputs the JSON representation of the map. -func (m *Map) ToJSON() ([]byte, error) { - elements := make(map[string]interface{}) - for key, value := range m.m { - elements[utils.ToString(key)] = value - } - return json.Marshal(&elements) +func (m *Map[K, V]) ToJSON() ([]byte, error) { + return json.Marshal(m.m) } // FromJSON populates the map from the input JSON representation. -func (m *Map) FromJSON(data []byte) error { - elements := make(map[string]interface{}) - err := json.Unmarshal(data, &elements) - if err == nil { - m.Clear() - for key, value := range elements { - m.m[key] = value - } - } - return err +func (m *Map[K, V]) FromJSON(data []byte) error { + return json.Unmarshal(data, &m.m) } // UnmarshalJSON @implements json.Unmarshaler -func (m *Map) UnmarshalJSON(bytes []byte) error { +func (m *Map[K, V]) UnmarshalJSON(bytes []byte) error { return m.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (m *Map) MarshalJSON() ([]byte, error) { +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { return m.ToJSON() } diff --git a/maps/linkedhashmap/enumerable.go b/maps/linkedhashmap/enumerable.go index eafcaa5d..fc496e72 100644 --- a/maps/linkedhashmap/enumerable.go +++ b/maps/linkedhashmap/enumerable.go @@ -4,13 +4,13 @@ package linkedhashmap -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithKey = (*Map)(nil) +var _ containers.EnumerableWithKey[string, int] = (*Map[string, int])(nil) // Each calls the given function once for each element, passing that element's key and value. -func (m *Map) Each(f func(key interface{}, value interface{})) { +func (m *Map[K, V]) Each(f func(key K, value V)) { iterator := m.Iterator() for iterator.Next() { f(iterator.Key(), iterator.Value()) @@ -19,8 +19,8 @@ func (m *Map) Each(f func(key interface{}, value interface{})) { // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. -func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, interface{})) *Map { - newMap := New() +func (m *Map[K, V]) Map(f func(key1 K, value1 V) (K, V)) *Map[K, V] { + newMap := New[K, V]() iterator := m.Iterator() for iterator.Next() { key2, value2 := f(iterator.Key(), iterator.Value()) @@ -30,8 +30,8 @@ func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, int } // Select returns a new container containing all elements for which the given function returns a true value. -func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { - newMap := New() +func (m *Map[K, V]) Select(f func(key K, value V) bool) *Map[K, V] { + newMap := New[K, V]() iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -43,7 +43,7 @@ func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) Any(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -55,7 +55,7 @@ func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) All(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if !f(iterator.Key(), iterator.Value()) { @@ -68,12 +68,12 @@ func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (key,value) for which the function is true or nil,nil otherwise if no element // matches the criteria. -func (m *Map) Find(f func(key interface{}, value interface{}) bool) (interface{}, interface{}) { +func (m *Map[K, V]) Find(f func(key K, value V) bool) (k K, v V) { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { return iterator.Key(), iterator.Value() } } - return nil, nil + return k, v } diff --git a/maps/linkedhashmap/iterator.go b/maps/linkedhashmap/iterator.go index 4c141780..9ac39318 100644 --- a/maps/linkedhashmap/iterator.go +++ b/maps/linkedhashmap/iterator.go @@ -5,77 +5,78 @@ package linkedhashmap import ( - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/lists/doublylinkedlist" + "github.com/emirpasic/gods/v2/containers" + "github.com/emirpasic/gods/v2/lists/doublylinkedlist" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - iterator doublylinkedlist.Iterator - table map[interface{}]interface{} +type Iterator[K comparable, V any] struct { + iterator doublylinkedlist.Iterator[K] + table map[K]V } // Iterator returns a stateful iterator whose elements are key/value pairs. -func (m *Map) Iterator() Iterator { - return Iterator{ +func (m *Map[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{ iterator: m.ordering.Iterator(), - table: m.table} + table: m.table, + } } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { return iterator.iterator.Next() } // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { return iterator.iterator.Prev() } // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[K, V]) Value() V { key := iterator.iterator.Value() return iterator.table[key] } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() K { return iterator.iterator.Value() } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.iterator.End() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { return iterator.iterator.First() } // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { return iterator.iterator.Last() } @@ -83,7 +84,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -97,7 +98,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/maps/linkedhashmap/linkedhashmap.go b/maps/linkedhashmap/linkedhashmap.go index d625b3d7..bf5ef4c3 100644 --- a/maps/linkedhashmap/linkedhashmap.go +++ b/maps/linkedhashmap/linkedhashmap.go @@ -13,31 +13,32 @@ package linkedhashmap import ( "fmt" - "github.com/emirpasic/gods/lists/doublylinkedlist" - "github.com/emirpasic/gods/maps" "strings" + + "github.com/emirpasic/gods/v2/lists/doublylinkedlist" + "github.com/emirpasic/gods/v2/maps" ) // Assert Map implementation -var _ maps.Map = (*Map)(nil) +var _ maps.Map[string, int] = (*Map[string, int])(nil) // Map holds the elements in a regular hash table, and uses doubly-linked list to store key ordering. -type Map struct { - table map[interface{}]interface{} - ordering *doublylinkedlist.List +type Map[K comparable, V any] struct { + table map[K]V + ordering *doublylinkedlist.List[K] } // New instantiates a linked-hash-map. -func New() *Map { - return &Map{ - table: make(map[interface{}]interface{}), - ordering: doublylinkedlist.New(), +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + table: make(map[K]V), + ordering: doublylinkedlist.New[K](), } } // Put inserts key-value pair into the map. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Put(key interface{}, value interface{}) { +func (m *Map[K, V]) Put(key K, value V) { if _, contains := m.table[key]; !contains { m.ordering.Append(key) } @@ -47,15 +48,14 @@ func (m *Map) Put(key interface{}, value interface{}) { // Get searches the element in the map by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Get(key interface{}) (value interface{}, found bool) { - value = m.table[key] - found = value != nil - return +func (m *Map[K, V]) Get(key K) (value V, found bool) { + value, found = m.table[key] + return value, found } // Remove removes the element from the map by key. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Remove(key interface{}) { +func (m *Map[K, V]) Remove(key K) { if _, contains := m.table[key]; contains { delete(m.table, key) index := m.ordering.IndexOf(key) @@ -64,23 +64,23 @@ func (m *Map) Remove(key interface{}) { } // Empty returns true if map does not contain any elements -func (m *Map) Empty() bool { +func (m *Map[K, V]) Empty() bool { return m.Size() == 0 } // Size returns number of elements in the map. -func (m *Map) Size() int { +func (m *Map[K, V]) Size() int { return m.ordering.Size() } // Keys returns all keys in-order -func (m *Map) Keys() []interface{} { +func (m *Map[K, V]) Keys() []K { return m.ordering.Values() } // Values returns all values in-order based on the key. -func (m *Map) Values() []interface{} { - values := make([]interface{}, m.Size()) +func (m *Map[K, V]) Values() []V { + values := make([]V, m.Size()) count := 0 it := m.Iterator() for it.Next() { @@ -91,13 +91,13 @@ func (m *Map) Values() []interface{} { } // Clear removes all elements from the map. -func (m *Map) Clear() { - m.table = make(map[interface{}]interface{}) +func (m *Map[K, V]) Clear() { + clear(m.table) m.ordering.Clear() } // String returns a string representation of container -func (m *Map) String() string { +func (m *Map[K, V]) String() string { str := "LinkedHashMap\nmap[" it := m.Iterator() for it.Next() { diff --git a/maps/linkedhashmap/linkedhashmap_test.go b/maps/linkedhashmap/linkedhashmap_test.go index ca44792c..a8d597e8 100644 --- a/maps/linkedhashmap/linkedhashmap_test.go +++ b/maps/linkedhashmap/linkedhashmap_test.go @@ -6,13 +6,14 @@ package linkedhashmap import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestMapPut(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -25,12 +26,8 @@ func TestMapPut(t *testing.T) { if actualValue := m.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := m.Keys(), []interface{}{5, 6, 7, 3, 4, 1, 2}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{"e", "f", "g", "c", "d", "a", "b"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4, 5, 6, 7}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}) // key,expectedValue,expectedFound tests1 := [][]interface{}{ @@ -41,12 +38,12 @@ func TestMapPut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -54,7 +51,7 @@ func TestMapPut(t *testing.T) { } func TestMapRemove(t *testing.T) { - m := New() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -70,13 +67,9 @@ func TestMapRemove(t *testing.T) { m.Remove(8) m.Remove(5) - if actualValue, expectedValue := m.Keys(), []interface{}{3, 4, 1, 2}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d"}) - if actualValue, expectedValue := m.Values(), []interface{}{"c", "d", "a", "b"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } if actualValue := m.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } @@ -86,14 +79,14 @@ func TestMapRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -106,12 +99,8 @@ func TestMapRemove(t *testing.T) { m.Remove(2) m.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", m.Keys()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s", m.Values()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), nil) + testutils.SameElements(t, m.Values(), nil) if actualValue := m.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) } @@ -120,32 +109,13 @@ func TestMapRemove(t *testing.T) { } } -func sameElements(a []interface{}, b []interface{}) bool { - // If one is nil, the other must also be nil. - if (a == nil) != (b == nil) { - return false - } - - if len(a) != len(b) { - return false - } - - for i := range a { - if a[i] != b[i] { - return false - } - } - - return true -} - func TestMapEach(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 1) m.Put("a", 2) m.Put("b", 3) count := 0 - m.Each(func(key interface{}, value interface{}) { + m.Each(func(key string, value int) { count++ if actualValue, expectedValue := count, value; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -170,12 +140,12 @@ func TestMapEach(t *testing.T) { } func TestMapMap(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - mappedMap := m.Map(func(key1 interface{}, value1 interface{}) (key2 interface{}, value2 interface{}) { - return key1, value1.(int) * value1.(int) + mappedMap := m.Map(func(key1 string, value1 int) (key2 string, value2 int) { + return key1, value1 * value1 }) if actualValue, _ := mappedMap.Get("c"); actualValue != 9 { t.Errorf("Got %v expected %v", actualValue, "mapped: c") @@ -192,12 +162,12 @@ func TestMapMap(t *testing.T) { } func TestMapSelect(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("b", 1) m.Put("a", 2) - selectedMap := m.Select(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + selectedMap := m.Select(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if actualValue, _ := selectedMap.Get("b"); actualValue != 1 { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -211,18 +181,18 @@ func TestMapSelect(t *testing.T) { } func TestMapAny(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - any := m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 3 + any := m.Any(func(key string, value int) bool { + return value == 3 }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 4 + any = m.Any(func(key string, value int) bool { + return value == 4 }) if any != false { t.Errorf("Got %v expected %v", any, false) @@ -230,18 +200,18 @@ func TestMapAny(t *testing.T) { } func TestMapAll(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - all := m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "c" + all := m.All(func(key string, value int) bool { + return key >= "a" && key <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + all = m.All(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) @@ -249,38 +219,38 @@ func TestMapAll(t *testing.T) { } func TestMapFind(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - foundKey, foundValue := m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "c" + foundKey, foundValue := m.Find(func(key string, value int) bool { + return key == "c" }) if foundKey != "c" || foundValue != 3 { t.Errorf("Got %v -> %v expected %v -> %v", foundKey, foundValue, "c", 3) } - foundKey, foundValue = m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "x" + foundKey, foundValue = m.Find(func(key string, value int) bool { + return key == "x" }) - if foundKey != nil || foundValue != nil { - t.Errorf("Got %v at %v expected %v at %v", foundValue, foundKey, nil, nil) + if foundKey != "" || foundValue != 0 { + t.Errorf("Got %v at %v expected %v at %v", foundValue, foundKey, "", 0) } } func TestMapChaining(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - chainedMap := m.Select(func(key interface{}, value interface{}) bool { - return value.(int) > 1 - }).Map(func(key interface{}, value interface{}) (interface{}, interface{}) { - return key.(string) + key.(string), value.(int) * value.(int) + chainedMap := m.Select(func(key string, value int) bool { + return value > 1 + }).Map(func(key string, value int) (string, int) { + return key + key, value * value }) if actualValue := chainedMap.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, found := chainedMap.Get("aa"); actualValue != nil || found { + if actualValue, found := chainedMap.Get("aa"); actualValue != 0 || found { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue, found := chainedMap.Get("bb"); actualValue != 4 || !found { @@ -292,7 +262,7 @@ func TestMapChaining(t *testing.T) { } func TestMapIteratorNextOnEmpty(t *testing.T) { - m := New() + m := New[string, int]() it := m.Iterator() it = m.Iterator() for it.Next() { @@ -301,7 +271,7 @@ func TestMapIteratorNextOnEmpty(t *testing.T) { } func TestMapIteratorPrevOnEmpty(t *testing.T) { - m := New() + m := New[string, int]() it := m.Iterator() it = m.Iterator() for it.Prev() { @@ -310,7 +280,7 @@ func TestMapIteratorPrevOnEmpty(t *testing.T) { } func TestMapIteratorNext(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 1) m.Put("a", 2) m.Put("b", 3) @@ -347,7 +317,7 @@ func TestMapIteratorNext(t *testing.T) { } func TestMapIteratorPrev(t *testing.T) { - m := New() + m := New[string, int]() m.Put("c", 1) m.Put("a", 2) m.Put("b", 3) @@ -386,7 +356,7 @@ func TestMapIteratorPrev(t *testing.T) { } func TestMapIteratorBegin(t *testing.T) { - m := New() + m := New[int, string]() it := m.Iterator() it.Begin() m.Put(3, "c") @@ -402,7 +372,7 @@ func TestMapIteratorBegin(t *testing.T) { } func TestMapIteratorEnd(t *testing.T) { - m := New() + m := New[int, string]() it := m.Iterator() m.Put(3, "c") m.Put(1, "a") @@ -415,7 +385,7 @@ func TestMapIteratorEnd(t *testing.T) { } func TestMapIteratorFirst(t *testing.T) { - m := New() + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -429,7 +399,7 @@ func TestMapIteratorFirst(t *testing.T) { } func TestMapIteratorLast(t *testing.T) { - m := New() + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -444,13 +414,13 @@ func TestMapIteratorLast(t *testing.T) { func TestMapIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - m := New() + m := New[int, string]() it := m.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") @@ -459,7 +429,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (not found) { - m := New() + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -470,7 +440,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (found) { - m := New() + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -479,13 +449,13 @@ func TestMapIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -496,13 +466,13 @@ func TestMapIteratorNextTo(t *testing.T) { func TestMapIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - m := New() + m := New[int, string]() it := m.Iterator() it.End() for it.PrevTo(seek) { @@ -512,7 +482,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - m := New() + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -524,7 +494,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (found) { - m := New() + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -533,13 +503,13 @@ func TestMapIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -550,30 +520,36 @@ func TestMapIteratorPrevTo(t *testing.T) { func TestMapSerialization(t *testing.T) { for i := 0; i < 10; i++ { - original := New() + original := New[string, string]() original.Put("d", "4") original.Put("e", "5") original.Put("c", "3") original.Put("b", "2") original.Put("a", "1") - assertSerialization(original, "A", t) - serialized, err := original.ToJSON() if err != nil { t.Errorf("Got error %v", err) } - assertSerialization(original, "B", t) - deserialized := New() + deserialized := New[string, string]() err = deserialized.FromJSON(serialized) if err != nil { t.Errorf("Got error %v", err) } - assertSerialization(deserialized, "C", t) + + if original.Size() != deserialized.Size() { + t.Errorf("Got map of size %d, expected %d", original.Size(), deserialized.Size()) + } + original.Each(func(key string, expected string) { + actual, ok := deserialized.Get(key) + if !ok || actual != expected { + t.Errorf("Did not find expected value %v for key %v in deserialied map (got %v)", expected, key, actual) + } + }) } - m := New() + m := New[string, float64]() m.Put("a", 1.0) m.Put("b", 2.0) m.Put("c", 3.0) @@ -590,37 +566,14 @@ func TestMapSerialization(t *testing.T) { } func TestMapString(t *testing.T) { - c := New() + c := New[string, int]() c.Put("a", 1) if !strings.HasPrefix(c.String(), "LinkedHashMap") { t.Errorf("String should start with container name") } } -//noinspection GoBoolExpressions -func assertSerialization(m *Map, txt string, t *testing.T) { - if actualValue := m.Keys(); false || - actualValue[0].(string) != "d" || - actualValue[1].(string) != "e" || - actualValue[2].(string) != "c" || - actualValue[3].(string) != "b" || - actualValue[4].(string) != "a" { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[d,e,c,b,a]") - } - if actualValue := m.Values(); false || - actualValue[0].(string) != "4" || - actualValue[1].(string) != "5" || - actualValue[2].(string) != "3" || - actualValue[3].(string) != "2" || - actualValue[4].(string) != "1" { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[4,5,3,2,1]") - } - if actualValue, expectedValue := m.Size(), 5; actualValue != expectedValue { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, expectedValue) - } -} - -func benchmarkGet(b *testing.B, m *Map, size int) { +func benchmarkGet(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Get(n) @@ -628,15 +581,15 @@ func benchmarkGet(b *testing.B, m *Map, size int) { } } -func benchmarkPut(b *testing.B, m *Map, size int) { +func benchmarkPut(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } } } -func benchmarkRemove(b *testing.B, m *Map, size int) { +func benchmarkRemove(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Remove(n) @@ -647,9 +600,9 @@ func benchmarkRemove(b *testing.B, m *Map, size int) { func BenchmarkTreeMapGet100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -658,9 +611,9 @@ func BenchmarkTreeMapGet100(b *testing.B) { func BenchmarkTreeMapGet1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -669,9 +622,9 @@ func BenchmarkTreeMapGet1000(b *testing.B) { func BenchmarkTreeMapGet10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -680,9 +633,9 @@ func BenchmarkTreeMapGet10000(b *testing.B) { func BenchmarkTreeMapGet100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkGet(b, m, size) @@ -691,7 +644,7 @@ func BenchmarkTreeMapGet100000(b *testing.B) { func BenchmarkTreeMapPut100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() b.StartTimer() benchmarkPut(b, m, size) } @@ -699,9 +652,9 @@ func BenchmarkTreeMapPut100(b *testing.B) { func BenchmarkTreeMapPut1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -710,9 +663,9 @@ func BenchmarkTreeMapPut1000(b *testing.B) { func BenchmarkTreeMapPut10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -721,9 +674,9 @@ func BenchmarkTreeMapPut10000(b *testing.B) { func BenchmarkTreeMapPut100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkPut(b, m, size) @@ -732,9 +685,9 @@ func BenchmarkTreeMapPut100000(b *testing.B) { func BenchmarkTreeMapRemove100(b *testing.B) { b.StopTimer() size := 100 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -743,9 +696,9 @@ func BenchmarkTreeMapRemove100(b *testing.B) { func BenchmarkTreeMapRemove1000(b *testing.B) { b.StopTimer() size := 1000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -754,9 +707,9 @@ func BenchmarkTreeMapRemove1000(b *testing.B) { func BenchmarkTreeMapRemove10000(b *testing.B) { b.StopTimer() size := 10000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) @@ -765,9 +718,9 @@ func BenchmarkTreeMapRemove10000(b *testing.B) { func BenchmarkTreeMapRemove100000(b *testing.B) { b.StopTimer() size := 100000 - m := New() + m := New[int, int]() for n := 0; n < size; n++ { - m.Put(n, struct{}{}) + m.Put(n, n) } b.StartTimer() benchmarkRemove(b, m, size) diff --git a/maps/linkedhashmap/serialization.go b/maps/linkedhashmap/serialization.go index 9265f1db..6a83df95 100644 --- a/maps/linkedhashmap/serialization.go +++ b/maps/linkedhashmap/serialization.go @@ -6,17 +6,19 @@ package linkedhashmap import ( "bytes" + "cmp" "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + "slices" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Map)(nil) -var _ containers.JSONDeserializer = (*Map)(nil) +var _ containers.JSONSerializer = (*Map[string, int])(nil) +var _ containers.JSONDeserializer = (*Map[string, int])(nil) // ToJSON outputs the JSON representation of map. -func (m *Map) ToJSON() ([]byte, error) { +func (m *Map[K, V]) ToJSON() ([]byte, error) { var b []byte buf := bytes.NewBuffer(b) @@ -54,7 +56,7 @@ func (m *Map) ToJSON() ([]byte, error) { } // FromJSON populates map from the input JSON representation. -//func (m *Map) FromJSON(data []byte) error { +//func (m *Map[K, V]) FromJSON(data []byte) error { // elements := make(map[string]interface{}) // err := json.Unmarshal(data, &elements) // if err == nil { @@ -67,46 +69,42 @@ func (m *Map) ToJSON() ([]byte, error) { //} // FromJSON populates map from the input JSON representation. -func (m *Map) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (m *Map[K, V]) FromJSON(data []byte) error { + elements := make(map[K]V) err := json.Unmarshal(data, &elements) if err != nil { return err } - index := make(map[string]int) - var keys []interface{} + index := make(map[K]int) + var keys []K for key := range elements { keys = append(keys, key) esc, _ := json.Marshal(key) index[key] = bytes.Index(data, esc) } - byIndex := func(a, b interface{}) int { - key1 := a.(string) - key2 := b.(string) - index1 := index[key1] - index2 := index[key2] - return index1 - index2 + byIndex := func(key1, key2 K) int { + return cmp.Compare(index[key1], index[key2]) } - utils.Sort(keys, byIndex) + slices.SortFunc(keys, byIndex) m.Clear() for _, key := range keys { - m.Put(key, elements[key.(string)]) + m.Put(key, elements[key]) } return nil } // UnmarshalJSON @implements json.Unmarshaler -func (m *Map) UnmarshalJSON(bytes []byte) error { +func (m *Map[K, V]) UnmarshalJSON(bytes []byte) error { return m.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (m *Map) MarshalJSON() ([]byte, error) { +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { return m.ToJSON() } diff --git a/maps/maps.go b/maps/maps.go index cdce9f7b..b54f7dc5 100644 --- a/maps/maps.go +++ b/maps/maps.go @@ -15,16 +15,16 @@ // Reference: https://en.wikipedia.org/wiki/Associative_array package maps -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Map interface that all maps implement -type Map interface { - Put(key interface{}, value interface{}) - Get(key interface{}) (value interface{}, found bool) - Remove(key interface{}) - Keys() []interface{} +type Map[K comparable, V any] interface { + Put(key K, value V) + Get(key K) (value V, found bool) + Remove(key K) + Keys() []K - containers.Container + containers.Container[V] // Empty() bool // Size() int // Clear() @@ -33,8 +33,8 @@ type Map interface { } // BidiMap interface that all bidirectional maps implement (extends the Map interface) -type BidiMap interface { - GetKey(value interface{}) (key interface{}, found bool) +type BidiMap[K comparable, V comparable] interface { + GetKey(value V) (key K, found bool) - Map + Map[K, V] } diff --git a/maps/treebidimap/enumerable.go b/maps/treebidimap/enumerable.go index 8daef722..febeff46 100644 --- a/maps/treebidimap/enumerable.go +++ b/maps/treebidimap/enumerable.go @@ -4,13 +4,13 @@ package treebidimap -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithKey = (*Map)(nil) +var _ containers.EnumerableWithKey[string, int] = (*Map[string, int])(nil) // Each calls the given function once for each element, passing that element's key and value. -func (m *Map) Each(f func(key interface{}, value interface{})) { +func (m *Map[K, V]) Each(f func(key K, value V)) { iterator := m.Iterator() for iterator.Next() { f(iterator.Key(), iterator.Value()) @@ -19,8 +19,8 @@ func (m *Map) Each(f func(key interface{}, value interface{})) { // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. -func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, interface{})) *Map { - newMap := NewWith(m.keyComparator, m.valueComparator) +func (m *Map[K, V]) Map(f func(key1 K, value1 V) (K, V)) *Map[K, V] { + newMap := NewWith[K, V](m.forwardMap.Comparator, m.inverseMap.Comparator) iterator := m.Iterator() for iterator.Next() { key2, value2 := f(iterator.Key(), iterator.Value()) @@ -30,8 +30,8 @@ func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, int } // Select returns a new container containing all elements for which the given function returns a true value. -func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { - newMap := NewWith(m.keyComparator, m.valueComparator) +func (m *Map[K, V]) Select(f func(key K, value V) bool) *Map[K, V] { + newMap := NewWith[K, V](m.forwardMap.Comparator, m.inverseMap.Comparator) iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -43,7 +43,7 @@ func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) Any(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -55,7 +55,7 @@ func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) All(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if !f(iterator.Key(), iterator.Value()) { @@ -68,12 +68,12 @@ func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (key,value) for which the function is true or nil,nil otherwise if no element // matches the criteria. -func (m *Map) Find(f func(key interface{}, value interface{}) bool) (interface{}, interface{}) { +func (m *Map[K, V]) Find(f func(key K, value V) bool) (k K, v V) { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { return iterator.Key(), iterator.Value() } } - return nil, nil + return k, v } diff --git a/maps/treebidimap/iterator.go b/maps/treebidimap/iterator.go index 9961a110..bf86ea71 100644 --- a/maps/treebidimap/iterator.go +++ b/maps/treebidimap/iterator.go @@ -5,73 +5,73 @@ package treebidimap import ( - "github.com/emirpasic/gods/containers" - rbt "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/v2/containers" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - iterator rbt.Iterator +type Iterator[K comparable, V any] struct { + iterator *rbt.Iterator[K, V] } // Iterator returns a stateful iterator whose elements are key/value pairs. -func (m *Map) Iterator() Iterator { - return Iterator{iterator: m.forwardMap.Iterator()} +func (m *Map[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{iterator: m.forwardMap.Iterator()} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { return iterator.iterator.Next() } // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { return iterator.iterator.Prev() } // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { - return iterator.iterator.Value().(*data).value +func (iterator *Iterator[K, V]) Value() V { + return iterator.iterator.Value() } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() K { return iterator.iterator.Key() } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.iterator.End() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { return iterator.iterator.First() } // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { return iterator.iterator.Last() } @@ -79,7 +79,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -93,7 +93,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/maps/treebidimap/serialization.go b/maps/treebidimap/serialization.go index 2cccce64..b72907b9 100644 --- a/maps/treebidimap/serialization.go +++ b/maps/treebidimap/serialization.go @@ -6,43 +6,41 @@ package treebidimap import ( "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Map)(nil) -var _ containers.JSONDeserializer = (*Map)(nil) +var _ containers.JSONSerializer = (*Map[string, int])(nil) +var _ containers.JSONDeserializer = (*Map[string, int])(nil) // ToJSON outputs the JSON representation of the map. -func (m *Map) ToJSON() ([]byte, error) { - elements := make(map[string]interface{}) - it := m.Iterator() - for it.Next() { - elements[utils.ToString(it.Key())] = it.Value() - } - return json.Marshal(&elements) +func (m *Map[K, V]) ToJSON() ([]byte, error) { + return m.forwardMap.ToJSON() } // FromJSON populates the map from the input JSON representation. -func (m *Map) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (m *Map[K, V]) FromJSON(data []byte) error { + var elements map[K]V err := json.Unmarshal(data, &elements) - if err == nil { - m.Clear() - for key, value := range elements { - m.Put(key, value) - } + if err != nil { + return err } - return err + + m.Clear() + for key, value := range elements { + m.Put(key, value) + } + + return nil } // UnmarshalJSON @implements json.Unmarshaler -func (m *Map) UnmarshalJSON(bytes []byte) error { +func (m *Map[K, V]) UnmarshalJSON(bytes []byte) error { return m.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (m *Map) MarshalJSON() ([]byte, error) { +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { return m.ToJSON() } diff --git a/maps/treebidimap/treebidimap.go b/maps/treebidimap/treebidimap.go index 37af07e0..27d6c9ef 100644 --- a/maps/treebidimap/treebidimap.go +++ b/maps/treebidimap/treebidimap.go @@ -18,116 +18,101 @@ package treebidimap import ( + "cmp" "fmt" - "github.com/emirpasic/gods/maps" - "github.com/emirpasic/gods/trees/redblacktree" - "github.com/emirpasic/gods/utils" + "strings" + + "github.com/emirpasic/gods/v2/maps" + "github.com/emirpasic/gods/v2/trees/redblacktree" + "github.com/emirpasic/gods/v2/utils" ) // Assert Map implementation -var _ maps.BidiMap = (*Map)(nil) +var _ maps.BidiMap[string, int] = (*Map[string, int])(nil) // Map holds the elements in two red-black trees. -type Map struct { - forwardMap redblacktree.Tree - inverseMap redblacktree.Tree - keyComparator utils.Comparator - valueComparator utils.Comparator +type Map[K, V comparable] struct { + forwardMap redblacktree.Tree[K, V] + inverseMap redblacktree.Tree[V, K] } -type data struct { - key interface{} - value interface{} +// New instantiates a bidirectional map. +func New[K, V cmp.Ordered]() *Map[K, V] { + return &Map[K, V]{ + forwardMap: *redblacktree.New[K, V](), + inverseMap: *redblacktree.New[V, K](), + } } // NewWith instantiates a bidirectional map. -func NewWith(keyComparator utils.Comparator, valueComparator utils.Comparator) *Map { - return &Map{ - forwardMap: *redblacktree.NewWith(keyComparator), - inverseMap: *redblacktree.NewWith(valueComparator), - keyComparator: keyComparator, - valueComparator: valueComparator, +func NewWith[K, V comparable](keyComparator utils.Comparator[K], valueComparator utils.Comparator[V]) *Map[K, V] { + return &Map[K, V]{ + forwardMap: *redblacktree.NewWith[K, V](keyComparator), + inverseMap: *redblacktree.NewWith[V, K](valueComparator), } } -// NewWithIntComparators instantiates a bidirectional map with the IntComparator for key and value, i.e. keys and values are of type int. -func NewWithIntComparators() *Map { - return NewWith(utils.IntComparator, utils.IntComparator) -} - -// NewWithStringComparators instantiates a bidirectional map with the StringComparator for key and value, i.e. keys and values are of type string. -func NewWithStringComparators() *Map { - return NewWith(utils.StringComparator, utils.StringComparator) -} - // Put inserts element into the map. -func (m *Map) Put(key interface{}, value interface{}) { - if d, ok := m.forwardMap.Get(key); ok { - m.inverseMap.Remove(d.(*data).value) +func (m *Map[K, V]) Put(key K, value V) { + if v, ok := m.forwardMap.Get(key); ok { + m.inverseMap.Remove(v) } - if d, ok := m.inverseMap.Get(value); ok { - m.forwardMap.Remove(d.(*data).key) + if k, ok := m.inverseMap.Get(value); ok { + m.forwardMap.Remove(k) } - d := &data{key: key, value: value} - m.forwardMap.Put(key, d) - m.inverseMap.Put(value, d) + m.forwardMap.Put(key, value) + m.inverseMap.Put(value, key) } // Get searches the element in the map by key and returns its value or nil if key is not found in map. // Second return parameter is true if key was found, otherwise false. -func (m *Map) Get(key interface{}) (value interface{}, found bool) { - if d, ok := m.forwardMap.Get(key); ok { - return d.(*data).value, true - } - return nil, false +func (m *Map[K, V]) Get(key K) (value V, found bool) { + return m.forwardMap.Get(key) } // GetKey searches the element in the map by value and returns its key or nil if value is not found in map. // Second return parameter is true if value was found, otherwise false. -func (m *Map) GetKey(value interface{}) (key interface{}, found bool) { - if d, ok := m.inverseMap.Get(value); ok { - return d.(*data).key, true - } - return nil, false +func (m *Map[K, V]) GetKey(value V) (key K, found bool) { + return m.inverseMap.Get(value) } // Remove removes the element from the map by key. -func (m *Map) Remove(key interface{}) { - if d, found := m.forwardMap.Get(key); found { +func (m *Map[K, V]) Remove(key K) { + if v, found := m.forwardMap.Get(key); found { m.forwardMap.Remove(key) - m.inverseMap.Remove(d.(*data).value) + m.inverseMap.Remove(v) } } // Empty returns true if map does not contain any elements -func (m *Map) Empty() bool { +func (m *Map[K, V]) Empty() bool { return m.Size() == 0 } // Size returns number of elements in the map. -func (m *Map) Size() int { +func (m *Map[K, V]) Size() int { return m.forwardMap.Size() } // Keys returns all keys (ordered). -func (m *Map) Keys() []interface{} { +func (m *Map[K, V]) Keys() []K { return m.forwardMap.Keys() } // Values returns all values (ordered). -func (m *Map) Values() []interface{} { +func (m *Map[K, V]) Values() []V { return m.inverseMap.Keys() } // Clear removes all elements from the map. -func (m *Map) Clear() { +func (m *Map[K, V]) Clear() { m.forwardMap.Clear() m.inverseMap.Clear() } // String returns a string representation of container -func (m *Map) String() string { +func (m *Map[K, V]) String() string { str := "TreeBidiMap\nmap[" it := m.Iterator() for it.Next() { diff --git a/maps/treebidimap/treebidimap_test.go b/maps/treebidimap/treebidimap_test.go index 75296e80..8c579ca6 100644 --- a/maps/treebidimap/treebidimap_test.go +++ b/maps/treebidimap/treebidimap_test.go @@ -6,14 +6,14 @@ package treebidimap import ( "encoding/json" - "fmt" - "github.com/emirpasic/gods/utils" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestMapPut(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -26,12 +26,8 @@ func TestMapPut(t *testing.T) { if actualValue := m.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4, 5, 6, 7}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d", "e", "f", "g"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4, 5, 6, 7}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}) // key,expectedValue,expectedFound tests1 := [][]interface{}{ @@ -42,12 +38,12 @@ func TestMapPut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -55,7 +51,7 @@ func TestMapPut(t *testing.T) { } func TestMapRemove(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -71,13 +67,9 @@ func TestMapRemove(t *testing.T) { m.Remove(8) m.Remove(5) - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d"}) - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } if actualValue := m.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } @@ -87,14 +79,14 @@ func TestMapRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -107,12 +99,8 @@ func TestMapRemove(t *testing.T) { m.Remove(2) m.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", m.Keys()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s", m.Values()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), nil) + testutils.SameElements(t, m.Values(), nil) if actualValue := m.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) } @@ -122,7 +110,7 @@ func TestMapRemove(t *testing.T) { } func TestMapGetKey(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -141,12 +129,12 @@ func TestMapGetKey(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {nil, "x", false}, + {0, "x", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.GetKey(test[1]) + actualValue, actualFound := m.GetKey(test[1].(string)) if actualValue != test[0] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[0]) } @@ -173,12 +161,12 @@ func sameElements(a []interface{}, b []interface{}) bool { } func TestMapEach(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) count := 0 - m.Each(func(key interface{}, value interface{}) { + m.Each(func(key string, value int) { count++ if actualValue, expectedValue := count, value; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -203,12 +191,12 @@ func TestMapEach(t *testing.T) { } func TestMapMap(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - mappedMap := m.Map(func(key1 interface{}, value1 interface{}) (key2 interface{}, value2 interface{}) { - return key1, value1.(int) * value1.(int) + mappedMap := m.Map(func(key1 string, value1 int) (key2 string, value2 int) { + return key1, value1 * value1 }) if actualValue, _ := mappedMap.Get("a"); actualValue != 1 { t.Errorf("Got %v expected %v", actualValue, "mapped: a") @@ -225,12 +213,12 @@ func TestMapMap(t *testing.T) { } func TestMapSelect(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - selectedMap := m.Select(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + selectedMap := m.Select(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if actualValue, _ := selectedMap.Get("a"); actualValue != 1 { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -244,18 +232,18 @@ func TestMapSelect(t *testing.T) { } func TestMapAny(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - any := m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 3 + any := m.Any(func(key string, value int) bool { + return value == 3 }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 4 + any = m.Any(func(key string, value int) bool { + return value == 4 }) if any != false { t.Errorf("Got %v expected %v", any, false) @@ -263,18 +251,18 @@ func TestMapAny(t *testing.T) { } func TestMapAll(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - all := m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "c" + all := m.All(func(key string, value int) bool { + return key >= "a" && key <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + all = m.All(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) @@ -282,38 +270,38 @@ func TestMapAll(t *testing.T) { } func TestMapFind(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - foundKey, foundValue := m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "c" + foundKey, foundValue := m.Find(func(key string, value int) bool { + return key == "c" }) if foundKey != "c" || foundValue != 3 { t.Errorf("Got %v -> %v expected %v -> %v", foundKey, foundValue, "c", 3) } - foundKey, foundValue = m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "x" + foundKey, foundValue = m.Find(func(key string, value int) bool { + return key == "x" }) - if foundKey != nil || foundValue != nil { + if foundKey != "" || foundValue != 0 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundKey, nil, nil) } } func TestMapChaining(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - chainedMap := m.Select(func(key interface{}, value interface{}) bool { - return value.(int) > 1 - }).Map(func(key interface{}, value interface{}) (interface{}, interface{}) { - return key.(string) + key.(string), value.(int) * value.(int) + chainedMap := m.Select(func(key string, value int) bool { + return value > 1 + }).Map(func(key string, value int) (string, int) { + return key + key, value * value }) if actualValue := chainedMap.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, found := chainedMap.Get("aa"); actualValue != nil || found { + if actualValue, found := chainedMap.Get("aa"); actualValue != 0 || found { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue, found := chainedMap.Get("bb"); actualValue != 4 || !found { @@ -325,7 +313,7 @@ func TestMapChaining(t *testing.T) { } func TestMapIteratorNextOnEmpty(t *testing.T) { - m := NewWithStringComparators() + m := New[string, string]() it := m.Iterator() it = m.Iterator() for it.Next() { @@ -334,7 +322,7 @@ func TestMapIteratorNextOnEmpty(t *testing.T) { } func TestMapIteratorPrevOnEmpty(t *testing.T) { - m := NewWithStringComparators() + m := New[string, string]() it := m.Iterator() it = m.Iterator() for it.Prev() { @@ -343,7 +331,7 @@ func TestMapIteratorPrevOnEmpty(t *testing.T) { } func TestMapIteratorNext(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) @@ -380,7 +368,7 @@ func TestMapIteratorNext(t *testing.T) { } func TestMapIteratorPrev(t *testing.T) { - m := NewWith(utils.StringComparator, utils.IntComparator) + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) @@ -419,7 +407,7 @@ func TestMapIteratorPrev(t *testing.T) { } func TestMapIteratorBegin(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() it := m.Iterator() it.Begin() m.Put(3, "c") @@ -435,7 +423,7 @@ func TestMapIteratorBegin(t *testing.T) { } func TestMapIteratorEnd(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() it := m.Iterator() m.Put(3, "c") m.Put(1, "a") @@ -448,7 +436,7 @@ func TestMapIteratorEnd(t *testing.T) { } func TestMapIteratorFirst(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -462,7 +450,7 @@ func TestMapIteratorFirst(t *testing.T) { } func TestMapIteratorLast(t *testing.T) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -477,13 +465,13 @@ func TestMapIteratorLast(t *testing.T) { func TestMapIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() it := m.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") @@ -492,7 +480,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (not found) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -503,7 +491,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (found) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -512,13 +500,13 @@ func TestMapIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -529,13 +517,13 @@ func TestMapIteratorNextTo(t *testing.T) { func TestMapIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() it := m.Iterator() it.End() for it.PrevTo(seek) { @@ -545,7 +533,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -557,7 +545,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (found) { - m := NewWith(utils.IntComparator, utils.StringComparator) + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -566,13 +554,13 @@ func TestMapIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -583,30 +571,36 @@ func TestMapIteratorPrevTo(t *testing.T) { func TestMapSerialization(t *testing.T) { for i := 0; i < 10; i++ { - original := NewWith(utils.StringComparator, utils.StringComparator) + original := New[string, string]() original.Put("d", "4") original.Put("e", "5") original.Put("c", "3") original.Put("b", "2") original.Put("a", "1") - assertSerialization(original, "A", t) - serialized, err := original.ToJSON() if err != nil { t.Errorf("Got error %v", err) } - assertSerialization(original, "B", t) - deserialized := NewWith(utils.StringComparator, utils.StringComparator) + deserialized := New[string, string]() err = deserialized.FromJSON(serialized) if err != nil { t.Errorf("Got error %v", err) } - assertSerialization(deserialized, "C", t) + + if original.Size() != deserialized.Size() { + t.Errorf("Got map of size %d, expected %d", original.Size(), deserialized.Size()) + } + original.Each(func(key string, expected string) { + actual, ok := deserialized.Get(key) + if !ok || actual != expected { + t.Errorf("Did not find expected value %v for key %v in deserialied map (got %q)", expected, key, actual) + } + }) } - m := NewWith(utils.StringComparator, utils.Float64Comparator) + m := New[string, float64]() m.Put("a", 1.0) m.Put("b", 2.0) m.Put("c", 3.0) @@ -623,37 +617,14 @@ func TestMapSerialization(t *testing.T) { } func TestMapString(t *testing.T) { - c := NewWithStringComparators() + c := New[string, string]() c.Put("a", "a") if !strings.HasPrefix(c.String(), "TreeBidiMap") { t.Errorf("String should start with container name") } } -//noinspection GoBoolExpressions -func assertSerialization(m *Map, txt string, t *testing.T) { - if actualValue := m.Keys(); false || - actualValue[0].(string) != "a" || - actualValue[1].(string) != "b" || - actualValue[2].(string) != "c" || - actualValue[3].(string) != "d" || - actualValue[4].(string) != "e" { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[a,b,c,d,e]") - } - if actualValue := m.Values(); false || - actualValue[0].(string) != "1" || - actualValue[1].(string) != "2" || - actualValue[2].(string) != "3" || - actualValue[3].(string) != "4" || - actualValue[4].(string) != "5" { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[1,2,3,4,5]") - } - if actualValue, expectedValue := m.Size(), 5; actualValue != expectedValue { - t.Errorf("[%s] Got %v expected %v", txt, actualValue, expectedValue) - } -} - -func benchmarkGet(b *testing.B, m *Map, size int) { +func benchmarkGet(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Get(n) @@ -661,7 +632,7 @@ func benchmarkGet(b *testing.B, m *Map, size int) { } } -func benchmarkPut(b *testing.B, m *Map, size int) { +func benchmarkPut(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Put(n, n) @@ -669,7 +640,7 @@ func benchmarkPut(b *testing.B, m *Map, size int) { } } -func benchmarkRemove(b *testing.B, m *Map, size int) { +func benchmarkRemove(b *testing.B, m *Map[int, int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Remove(n) @@ -680,7 +651,7 @@ func benchmarkRemove(b *testing.B, m *Map, size int) { func BenchmarkTreeBidiMapGet100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -691,7 +662,7 @@ func BenchmarkTreeBidiMapGet100(b *testing.B) { func BenchmarkTreeBidiMapGet1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -702,7 +673,7 @@ func BenchmarkTreeBidiMapGet1000(b *testing.B) { func BenchmarkTreeBidiMapGet10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -713,7 +684,7 @@ func BenchmarkTreeBidiMapGet10000(b *testing.B) { func BenchmarkTreeBidiMapGet100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -724,7 +695,7 @@ func BenchmarkTreeBidiMapGet100000(b *testing.B) { func BenchmarkTreeBidiMapPut100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparators() + m := New[int, int]() b.StartTimer() benchmarkPut(b, m, size) } @@ -732,7 +703,7 @@ func BenchmarkTreeBidiMapPut100(b *testing.B) { func BenchmarkTreeBidiMapPut1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -743,7 +714,7 @@ func BenchmarkTreeBidiMapPut1000(b *testing.B) { func BenchmarkTreeBidiMapPut10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -754,7 +725,7 @@ func BenchmarkTreeBidiMapPut10000(b *testing.B) { func BenchmarkTreeBidiMapPut100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -765,7 +736,7 @@ func BenchmarkTreeBidiMapPut100000(b *testing.B) { func BenchmarkTreeBidiMapRemove100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -776,7 +747,7 @@ func BenchmarkTreeBidiMapRemove100(b *testing.B) { func BenchmarkTreeBidiMapRemove1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -787,7 +758,7 @@ func BenchmarkTreeBidiMapRemove1000(b *testing.B) { func BenchmarkTreeBidiMapRemove10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } @@ -798,7 +769,7 @@ func BenchmarkTreeBidiMapRemove10000(b *testing.B) { func BenchmarkTreeBidiMapRemove100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparators() + m := New[int, int]() for n := 0; n < size; n++ { m.Put(n, n) } diff --git a/maps/treemap/enumerable.go b/maps/treemap/enumerable.go index 34b3704d..7b0df161 100644 --- a/maps/treemap/enumerable.go +++ b/maps/treemap/enumerable.go @@ -5,15 +5,15 @@ package treemap import ( - "github.com/emirpasic/gods/containers" - rbt "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/v2/containers" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // Assert Enumerable implementation -var _ containers.EnumerableWithKey = (*Map)(nil) +var _ containers.EnumerableWithKey[string, int] = (*Map[string, int])(nil) // Each calls the given function once for each element, passing that element's key and value. -func (m *Map) Each(f func(key interface{}, value interface{})) { +func (m *Map[K, V]) Each(f func(key K, value V)) { iterator := m.Iterator() for iterator.Next() { f(iterator.Key(), iterator.Value()) @@ -22,8 +22,8 @@ func (m *Map) Each(f func(key interface{}, value interface{})) { // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. -func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, interface{})) *Map { - newMap := &Map{tree: rbt.NewWith(m.tree.Comparator)} +func (m *Map[K, V]) Map(f func(key1 K, value1 V) (K, V)) *Map[K, V] { + newMap := &Map[K, V]{tree: rbt.NewWith[K, V](m.tree.Comparator)} iterator := m.Iterator() for iterator.Next() { key2, value2 := f(iterator.Key(), iterator.Value()) @@ -33,8 +33,8 @@ func (m *Map) Map(f func(key1 interface{}, value1 interface{}) (interface{}, int } // Select returns a new container containing all elements for which the given function returns a true value. -func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { - newMap := &Map{tree: rbt.NewWith(m.tree.Comparator)} +func (m *Map[K, V]) Select(f func(key K, value V) bool) *Map[K, V] { + newMap := &Map[K, V]{tree: rbt.NewWith[K, V](m.tree.Comparator)} iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -46,7 +46,7 @@ func (m *Map) Select(f func(key interface{}, value interface{}) bool) *Map { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) Any(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { @@ -58,7 +58,7 @@ func (m *Map) Any(f func(key interface{}, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { +func (m *Map[K, V]) All(f func(key K, value V) bool) bool { iterator := m.Iterator() for iterator.Next() { if !f(iterator.Key(), iterator.Value()) { @@ -71,12 +71,12 @@ func (m *Map) All(f func(key interface{}, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (key,value) for which the function is true or nil,nil otherwise if no element // matches the criteria. -func (m *Map) Find(f func(key interface{}, value interface{}) bool) (interface{}, interface{}) { +func (m *Map[K, V]) Find(f func(key K, value V) bool) (k K, v V) { iterator := m.Iterator() for iterator.Next() { if f(iterator.Key(), iterator.Value()) { return iterator.Key(), iterator.Value() } } - return nil, nil + return k, v } diff --git a/maps/treemap/iterator.go b/maps/treemap/iterator.go index becb56db..9bdf91b5 100644 --- a/maps/treemap/iterator.go +++ b/maps/treemap/iterator.go @@ -5,73 +5,73 @@ package treemap import ( - "github.com/emirpasic/gods/containers" - rbt "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/v2/containers" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - iterator rbt.Iterator +type Iterator[K comparable, V any] struct { + iterator *rbt.Iterator[K, V] } // Iterator returns a stateful iterator whose elements are key/value pairs. -func (m *Map) Iterator() Iterator { - return Iterator{iterator: m.tree.Iterator()} +func (m *Map[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{iterator: m.tree.Iterator()} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { return iterator.iterator.Next() } // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { return iterator.iterator.Prev() } // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[K, V]) Value() V { return iterator.iterator.Value() } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() K { return iterator.iterator.Key() } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.iterator.End() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { return iterator.iterator.First() } // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { return iterator.iterator.Last() } @@ -79,7 +79,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -93,7 +93,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/maps/treemap/serialization.go b/maps/treemap/serialization.go index 415a77dd..c63c4972 100644 --- a/maps/treemap/serialization.go +++ b/maps/treemap/serialization.go @@ -5,29 +5,29 @@ package treemap import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Map)(nil) -var _ containers.JSONDeserializer = (*Map)(nil) +var _ containers.JSONSerializer = (*Map[string, int])(nil) +var _ containers.JSONDeserializer = (*Map[string, int])(nil) // ToJSON outputs the JSON representation of the map. -func (m *Map) ToJSON() ([]byte, error) { +func (m *Map[K, V]) ToJSON() ([]byte, error) { return m.tree.ToJSON() } // FromJSON populates the map from the input JSON representation. -func (m *Map) FromJSON(data []byte) error { +func (m *Map[K, V]) FromJSON(data []byte) error { return m.tree.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (m *Map) UnmarshalJSON(bytes []byte) error { +func (m *Map[K, V]) UnmarshalJSON(bytes []byte) error { return m.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (m *Map) MarshalJSON() ([]byte, error) { +func (m *Map[K, V]) MarshalJSON() ([]byte, error) { return m.ToJSON() } diff --git a/maps/treemap/treemap.go b/maps/treemap/treemap.go index a77d16d8..f3766622 100644 --- a/maps/treemap/treemap.go +++ b/maps/treemap/treemap.go @@ -12,96 +12,93 @@ package treemap import ( + "cmp" "fmt" - "github.com/emirpasic/gods/maps" - rbt "github.com/emirpasic/gods/trees/redblacktree" - "github.com/emirpasic/gods/utils" "strings" + + "github.com/emirpasic/gods/v2/maps" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" + "github.com/emirpasic/gods/v2/utils" ) // Assert Map implementation -var _ maps.Map = (*Map)(nil) +var _ maps.Map[string, int] = (*Map[string, int])(nil) // Map holds the elements in a red-black tree -type Map struct { - tree *rbt.Tree +type Map[K comparable, V any] struct { + tree *rbt.Tree[K, V] } -// NewWith instantiates a tree map with the custom comparator. -func NewWith(comparator utils.Comparator) *Map { - return &Map{tree: rbt.NewWith(comparator)} +// New instantiates a tree map with the built-in comparator for K +func New[K cmp.Ordered, V any]() *Map[K, V] { + return &Map[K, V]{tree: rbt.New[K, V]()} } -// NewWithIntComparator instantiates a tree map with the IntComparator, i.e. keys are of type int. -func NewWithIntComparator() *Map { - return &Map{tree: rbt.NewWithIntComparator()} -} - -// NewWithStringComparator instantiates a tree map with the StringComparator, i.e. keys are of type string. -func NewWithStringComparator() *Map { - return &Map{tree: rbt.NewWithStringComparator()} +// NewWith instantiates a tree map with the custom comparator. +func NewWith[K comparable, V any](comparator utils.Comparator[K]) *Map[K, V] { + return &Map[K, V]{tree: rbt.NewWith[K, V](comparator)} } // Put inserts key-value pair into the map. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Put(key interface{}, value interface{}) { +func (m *Map[K, V]) Put(key K, value V) { m.tree.Put(key, value) } // Get searches the element in the map by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Get(key interface{}) (value interface{}, found bool) { +func (m *Map[K, V]) Get(key K) (value V, found bool) { return m.tree.Get(key) } // Remove removes the element from the map by key. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Remove(key interface{}) { +func (m *Map[K, V]) Remove(key K) { m.tree.Remove(key) } // Empty returns true if map does not contain any elements -func (m *Map) Empty() bool { +func (m *Map[K, V]) Empty() bool { return m.tree.Empty() } // Size returns number of elements in the map. -func (m *Map) Size() int { +func (m *Map[K, V]) Size() int { return m.tree.Size() } // Keys returns all keys in-order -func (m *Map) Keys() []interface{} { +func (m *Map[K, V]) Keys() []K { return m.tree.Keys() } // Values returns all values in-order based on the key. -func (m *Map) Values() []interface{} { +func (m *Map[K, V]) Values() []V { return m.tree.Values() } // Clear removes all elements from the map. -func (m *Map) Clear() { +func (m *Map[K, V]) Clear() { m.tree.Clear() } // Min returns the minimum key and its value from the tree map. -// Returns nil, nil if map is empty. -func (m *Map) Min() (key interface{}, value interface{}) { +// Returns 0-value, 0-value, false if map is empty. +func (m *Map[K, V]) Min() (key K, value V, ok bool) { if node := m.tree.Left(); node != nil { - return node.Key, node.Value + return node.Key, node.Value, true } - return nil, nil + return key, value, false } // Max returns the maximum key and its value from the tree map. -// Returns nil, nil if map is empty. -func (m *Map) Max() (key interface{}, value interface{}) { +// Returns 0-value, 0-value, false if map is empty. +func (m *Map[K, V]) Max() (key K, value V, ok bool) { if node := m.tree.Right(); node != nil { - return node.Key, node.Value + return node.Key, node.Value, true } - return nil, nil + return key, value, false } // Floor finds the floor key-value pair for the input key. @@ -113,12 +110,12 @@ func (m *Map) Max() (key interface{}, value interface{}) { // all keys in the map are larger than the given key. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Floor(key interface{}) (foundKey interface{}, foundValue interface{}) { +func (m *Map[K, V]) Floor(key K) (foundKey K, foundValue V, ok bool) { node, found := m.tree.Floor(key) if found { - return node.Key, node.Value + return node.Key, node.Value, true } - return nil, nil + return foundKey, foundValue, false } // Ceiling finds the ceiling key-value pair for the input key. @@ -130,16 +127,16 @@ func (m *Map) Floor(key interface{}) (foundKey interface{}, foundValue interface // all keys in the map are smaller than the given key. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (m *Map) Ceiling(key interface{}) (foundKey interface{}, foundValue interface{}) { +func (m *Map[K, V]) Ceiling(key K) (foundKey K, foundValue V, ok bool) { node, found := m.tree.Ceiling(key) if found { - return node.Key, node.Value + return node.Key, node.Value, true } - return nil, nil + return foundKey, foundValue, false } // String returns a string representation of container -func (m *Map) String() string { +func (m *Map[K, V]) String() string { str := "TreeMap\nmap[" it := m.Iterator() for it.Next() { diff --git a/maps/treemap/treemap_test.go b/maps/treemap/treemap_test.go index f0c96baa..e608f0e6 100644 --- a/maps/treemap/treemap_test.go +++ b/maps/treemap/treemap_test.go @@ -6,14 +6,14 @@ package treemap import ( "encoding/json" - "fmt" - "github.com/emirpasic/gods/utils" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestMapPut(t *testing.T) { - m := NewWith(utils.IntComparator) + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -26,12 +26,8 @@ func TestMapPut(t *testing.T) { if actualValue := m.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4, 5, 6, 7}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d", "e", "f", "g"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4, 5, 6, 7}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}) // key,expectedValue,expectedFound tests1 := [][]interface{}{ @@ -42,12 +38,12 @@ func TestMapPut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -55,10 +51,10 @@ func TestMapPut(t *testing.T) { } func TestMapMin(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() - if k, v := m.Min(); k != nil || v != nil { - t.Errorf("Got %v->%v expected %v->%v", k, v, nil, nil) + if k, v, ok := m.Min(); k != 0 || v != "" || ok { + t.Errorf("Got %v->%v->%v expected %v->%v-%v", k, v, ok, 0, "", false) } m.Put(5, "e") @@ -70,21 +66,24 @@ func TestMapMin(t *testing.T) { m.Put(2, "b") m.Put(1, "a") //overwrite - actualKey, actualValue := m.Min() - expectedKey, expectedValue := 1, "a" + actualKey, actualValue, actualOk := m.Min() + expectedKey, expectedValue, expectedOk := 1, "a", true if actualKey != expectedKey { t.Errorf("Got %v expected %v", actualKey, expectedKey) } if actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } + if actualOk != expectedOk { + t.Errorf("Got %v expected %v", actualOk, expectedOk) + } } func TestMapMax(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() - if k, v := m.Max(); k != nil || v != nil { - t.Errorf("Got %v->%v expected %v->%v", k, v, nil, nil) + if k, v, ok := m.Max(); k != 0 || v != "" || ok { + t.Errorf("Got %v->%v->%v expected %v->%v-%v", k, v, ok, 0, "", false) } m.Put(5, "e") @@ -96,18 +95,21 @@ func TestMapMax(t *testing.T) { m.Put(2, "b") m.Put(1, "a") //overwrite - actualKey, actualValue := m.Max() - expectedKey, expectedValue := 7, "g" + actualKey, actualValue, actualOk := m.Max() + expectedKey, expectedValue, expectedOk := 7, "g", true if actualKey != expectedKey { t.Errorf("Got %v expected %v", actualKey, expectedKey) } if actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } + if actualOk != expectedOk { + t.Errorf("Got %v expected %v", actualOk, expectedOk) + } } func TestMapClear(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -122,7 +124,7 @@ func TestMapClear(t *testing.T) { } func TestMapRemove(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(5, "e") m.Put(6, "f") m.Put(7, "g") @@ -138,13 +140,9 @@ func TestMapRemove(t *testing.T) { m.Remove(8) m.Remove(5) - if actualValue, expectedValue := m.Keys(), []interface{}{1, 2, 3, 4}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), []int{1, 2, 3, 4}) + testutils.SameElements(t, m.Values(), []string{"a", "b", "c", "d"}) - if actualValue, expectedValue := m.Values(), []interface{}{"a", "b", "c", "d"}; !sameElements(actualValue, expectedValue) { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } if actualValue := m.Size(); actualValue != 4 { t.Errorf("Got %v expected %v", actualValue, 4) } @@ -154,14 +152,14 @@ func TestMapRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := m.Get(test[0]) + actualValue, actualFound := m.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -174,12 +172,8 @@ func TestMapRemove(t *testing.T) { m.Remove(2) m.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", m.Keys()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s", m.Values()), "[]"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, m.Keys(), nil) + testutils.SameElements(t, m.Values(), nil) if actualValue := m.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) } @@ -189,15 +183,15 @@ func TestMapRemove(t *testing.T) { } func TestMapFloor(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(7, "g") m.Put(3, "c") m.Put(1, "a") // key,expectedKey,expectedValue,expectedFound tests1 := [][]interface{}{ - {-1, nil, nil, false}, - {0, nil, nil, false}, + {-1, 0, "", false}, + {0, 0, "", false}, {1, 1, "a", true}, {2, 1, "a", true}, {3, 3, "c", true}, @@ -208,16 +202,15 @@ func TestMapFloor(t *testing.T) { for _, test := range tests1 { // retrievals - actualKey, actualValue := m.Floor(test[0]) - actualFound := actualKey != nil && actualValue != nil - if actualKey != test[1] || actualValue != test[2] || actualFound != test[3] { - t.Errorf("Got %v, %v, %v, expected %v, %v, %v", actualKey, actualValue, actualFound, test[1], test[2], test[3]) + actualKey, actualValue, actualOk := m.Floor(test[0].(int)) + if actualKey != test[1] || actualValue != test[2] || actualOk != test[3] { + t.Errorf("Got %v, %v, %v, expected %v, %v, %v", actualKey, actualValue, actualOk, test[1], test[2], test[3]) } } } func TestMapCeiling(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(7, "g") m.Put(3, "c") m.Put(1, "a") @@ -231,45 +224,25 @@ func TestMapCeiling(t *testing.T) { {3, 3, "c", true}, {4, 7, "g", true}, {7, 7, "g", true}, - {8, nil, nil, false}, + {8, 0, "", false}, } for _, test := range tests1 { // retrievals - actualKey, actualValue := m.Ceiling(test[0]) - actualFound := actualKey != nil && actualValue != nil - if actualKey != test[1] || actualValue != test[2] || actualFound != test[3] { - t.Errorf("Got %v, %v, %v, expected %v, %v, %v", actualKey, actualValue, actualFound, test[1], test[2], test[3]) - } - } -} - -func sameElements(a []interface{}, b []interface{}) bool { - if len(a) != len(b) { - return false - } - for _, av := range a { - found := false - for _, bv := range b { - if av == bv { - found = true - break - } - } - if !found { - return false + actualKey, actualValue, actualOk := m.Ceiling(test[0].(int)) + if actualKey != test[1] || actualValue != test[2] || actualOk != test[3] { + t.Errorf("Got %v, %v, %v, expected %v, %v, %v", actualKey, actualValue, actualOk, test[1], test[2], test[3]) } } - return true } func TestMapEach(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) count := 0 - m.Each(func(key interface{}, value interface{}) { + m.Each(func(key string, value int) { count++ if actualValue, expectedValue := count, value; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -294,12 +267,12 @@ func TestMapEach(t *testing.T) { } func TestMapMap(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - mappedMap := m.Map(func(key1 interface{}, value1 interface{}) (key2 interface{}, value2 interface{}) { - return key1, value1.(int) * value1.(int) + mappedMap := m.Map(func(key1 string, value1 int) (key2 string, value2 int) { + return key1, value1 * value1 }) if actualValue, _ := mappedMap.Get("a"); actualValue != 1 { t.Errorf("Got %v expected %v", actualValue, "mapped: a") @@ -316,12 +289,12 @@ func TestMapMap(t *testing.T) { } func TestMapSelect(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - selectedMap := m.Select(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + selectedMap := m.Select(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if actualValue, _ := selectedMap.Get("a"); actualValue != 1 { t.Errorf("Got %v expected %v", actualValue, "value: a") @@ -335,18 +308,18 @@ func TestMapSelect(t *testing.T) { } func TestMapAny(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - any := m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 3 + any := m.Any(func(key string, value int) bool { + return value == 3 }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = m.Any(func(key interface{}, value interface{}) bool { - return value.(int) == 4 + any = m.Any(func(key string, value int) bool { + return value == 4 }) if any != false { t.Errorf("Got %v expected %v", any, false) @@ -354,18 +327,18 @@ func TestMapAny(t *testing.T) { } func TestMapAll(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - all := m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "c" + all := m.All(func(key string, value int) bool { + return key >= "a" && key <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = m.All(func(key interface{}, value interface{}) bool { - return key.(string) >= "a" && key.(string) <= "b" + all = m.All(func(key string, value int) bool { + return key >= "a" && key <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) @@ -373,38 +346,38 @@ func TestMapAll(t *testing.T) { } func TestMapFind(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - foundKey, foundValue := m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "c" + foundKey, foundValue := m.Find(func(key string, value int) bool { + return key == "c" }) if foundKey != "c" || foundValue != 3 { t.Errorf("Got %v -> %v expected %v -> %v", foundKey, foundValue, "c", 3) } - foundKey, foundValue = m.Find(func(key interface{}, value interface{}) bool { - return key.(string) == "x" + foundKey, foundValue = m.Find(func(key string, value int) bool { + return key == "x" }) - if foundKey != nil || foundValue != nil { + if foundKey != "" || foundValue != 0 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundKey, nil, nil) } } func TestMapChaining(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) - chainedMap := m.Select(func(key interface{}, value interface{}) bool { - return value.(int) > 1 - }).Map(func(key interface{}, value interface{}) (interface{}, interface{}) { - return key.(string) + key.(string), value.(int) * value.(int) + chainedMap := m.Select(func(key string, value int) bool { + return value > 1 + }).Map(func(key string, value int) (string, int) { + return key + key, value * value }) if actualValue := chainedMap.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) } - if actualValue, found := chainedMap.Get("aa"); actualValue != nil || found { + if actualValue, found := chainedMap.Get("aa"); actualValue != 0 || found { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue, found := chainedMap.Get("bb"); actualValue != 4 || !found { @@ -416,7 +389,7 @@ func TestMapChaining(t *testing.T) { } func TestMapIteratorNextOnEmpty(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() it := m.Iterator() it = m.Iterator() for it.Next() { @@ -425,7 +398,7 @@ func TestMapIteratorNextOnEmpty(t *testing.T) { } func TestMapIteratorPrevOnEmpty(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() it := m.Iterator() it = m.Iterator() for it.Prev() { @@ -434,7 +407,7 @@ func TestMapIteratorPrevOnEmpty(t *testing.T) { } func TestMapIteratorNext(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) @@ -471,7 +444,7 @@ func TestMapIteratorNext(t *testing.T) { } func TestMapIteratorPrev(t *testing.T) { - m := NewWithStringComparator() + m := New[string, int]() m.Put("c", 3) m.Put("a", 1) m.Put("b", 2) @@ -510,7 +483,7 @@ func TestMapIteratorPrev(t *testing.T) { } func TestMapIteratorBegin(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() it := m.Iterator() it.Begin() m.Put(3, "c") @@ -526,7 +499,7 @@ func TestMapIteratorBegin(t *testing.T) { } func TestMapIteratorEnd(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() it := m.Iterator() m.Put(3, "c") m.Put(1, "a") @@ -539,7 +512,7 @@ func TestMapIteratorEnd(t *testing.T) { } func TestMapIteratorFirst(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -553,7 +526,7 @@ func TestMapIteratorFirst(t *testing.T) { } func TestMapIteratorLast(t *testing.T) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(3, "c") m.Put(1, "a") m.Put(2, "b") @@ -568,13 +541,13 @@ func TestMapIteratorLast(t *testing.T) { func TestMapIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - m := NewWithIntComparator() + m := New[int, string]() it := m.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") @@ -583,7 +556,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (not found) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -594,7 +567,7 @@ func TestMapIteratorNextTo(t *testing.T) { // NextTo (found) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -603,13 +576,13 @@ func TestMapIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -620,13 +593,13 @@ func TestMapIteratorNextTo(t *testing.T) { func TestMapIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - m := NewWithIntComparator() + m := New[int, string]() it := m.Iterator() it.End() for it.PrevTo(seek) { @@ -636,7 +609,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(0, "xx") m.Put(1, "yy") it := m.Iterator() @@ -648,7 +621,7 @@ func TestMapIteratorPrevTo(t *testing.T) { // PrevTo (found) { - m := NewWithIntComparator() + m := New[int, string]() m.Put(0, "aa") m.Put(1, "bb") m.Put(2, "cc") @@ -657,13 +630,13 @@ func TestMapIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty map") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -674,7 +647,7 @@ func TestMapIteratorPrevTo(t *testing.T) { func TestMapSerialization(t *testing.T) { for i := 0; i < 10; i++ { - original := NewWithStringComparator() + original := New[string, string]() original.Put("d", "4") original.Put("e", "5") original.Put("c", "3") @@ -689,7 +662,7 @@ func TestMapSerialization(t *testing.T) { } assertSerialization(original, "B", t) - deserialized := NewWithStringComparator() + deserialized := New[string, string]() err = deserialized.FromJSON(serialized) if err != nil { t.Errorf("Got error %v", err) @@ -697,7 +670,7 @@ func TestMapSerialization(t *testing.T) { assertSerialization(deserialized, "C", t) } - m := NewWithStringComparator() + m := New[string, float64]() m.Put("a", 1.0) m.Put("b", 2.0) m.Put("c", 3.0) @@ -714,29 +687,29 @@ func TestMapSerialization(t *testing.T) { } func TestMapString(t *testing.T) { - c := NewWithStringComparator() + c := New[string, int]() c.Put("a", 1) if !strings.HasPrefix(c.String(), "TreeMap") { t.Errorf("String should start with container name") } } -//noinspection GoBoolExpressions -func assertSerialization(m *Map, txt string, t *testing.T) { +// noinspection GoBoolExpressions +func assertSerialization(m *Map[string, string], txt string, t *testing.T) { if actualValue := m.Keys(); false || - actualValue[0].(string) != "a" || - actualValue[1].(string) != "b" || - actualValue[2].(string) != "c" || - actualValue[3].(string) != "d" || - actualValue[4].(string) != "e" { + actualValue[0] != "a" || + actualValue[1] != "b" || + actualValue[2] != "c" || + actualValue[3] != "d" || + actualValue[4] != "e" { t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[a,b,c,d,e]") } if actualValue := m.Values(); false || - actualValue[0].(string) != "1" || - actualValue[1].(string) != "2" || - actualValue[2].(string) != "3" || - actualValue[3].(string) != "4" || - actualValue[4].(string) != "5" { + actualValue[0] != "1" || + actualValue[1] != "2" || + actualValue[2] != "3" || + actualValue[3] != "4" || + actualValue[4] != "5" { t.Errorf("[%s] Got %v expected %v", txt, actualValue, "[1,2,3,4,5]") } if actualValue, expectedValue := m.Size(), 5; actualValue != expectedValue { @@ -744,7 +717,7 @@ func assertSerialization(m *Map, txt string, t *testing.T) { } } -func benchmarkGet(b *testing.B, m *Map, size int) { +func benchmarkGet(b *testing.B, m *Map[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Get(n) @@ -752,7 +725,7 @@ func benchmarkGet(b *testing.B, m *Map, size int) { } } -func benchmarkPut(b *testing.B, m *Map, size int) { +func benchmarkPut(b *testing.B, m *Map[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Put(n, struct{}{}) @@ -760,7 +733,7 @@ func benchmarkPut(b *testing.B, m *Map, size int) { } } -func benchmarkRemove(b *testing.B, m *Map, size int) { +func benchmarkRemove(b *testing.B, m *Map[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { m.Remove(n) @@ -771,7 +744,7 @@ func benchmarkRemove(b *testing.B, m *Map, size int) { func BenchmarkTreeMapGet100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -782,7 +755,7 @@ func BenchmarkTreeMapGet100(b *testing.B) { func BenchmarkTreeMapGet1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -793,7 +766,7 @@ func BenchmarkTreeMapGet1000(b *testing.B) { func BenchmarkTreeMapGet10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -804,7 +777,7 @@ func BenchmarkTreeMapGet10000(b *testing.B) { func BenchmarkTreeMapGet100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -815,7 +788,7 @@ func BenchmarkTreeMapGet100000(b *testing.B) { func BenchmarkTreeMapPut100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparator() + m := New[int, struct{}]() b.StartTimer() benchmarkPut(b, m, size) } @@ -823,7 +796,7 @@ func BenchmarkTreeMapPut100(b *testing.B) { func BenchmarkTreeMapPut1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -834,7 +807,7 @@ func BenchmarkTreeMapPut1000(b *testing.B) { func BenchmarkTreeMapPut10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -845,7 +818,7 @@ func BenchmarkTreeMapPut10000(b *testing.B) { func BenchmarkTreeMapPut100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -856,7 +829,7 @@ func BenchmarkTreeMapPut100000(b *testing.B) { func BenchmarkTreeMapRemove100(b *testing.B) { b.StopTimer() size := 100 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -867,7 +840,7 @@ func BenchmarkTreeMapRemove100(b *testing.B) { func BenchmarkTreeMapRemove1000(b *testing.B) { b.StopTimer() size := 1000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -878,7 +851,7 @@ func BenchmarkTreeMapRemove1000(b *testing.B) { func BenchmarkTreeMapRemove10000(b *testing.B) { b.StopTimer() size := 10000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } @@ -889,7 +862,7 @@ func BenchmarkTreeMapRemove10000(b *testing.B) { func BenchmarkTreeMapRemove100000(b *testing.B) { b.StopTimer() size := 100000 - m := NewWithIntComparator() + m := New[int, struct{}]() for n := 0; n < size; n++ { m.Put(n, struct{}{}) } diff --git a/queues/arrayqueue/arrayqueue.go b/queues/arrayqueue/arrayqueue.go index 8d480e90..6c57c722 100644 --- a/queues/arrayqueue/arrayqueue.go +++ b/queues/arrayqueue/arrayqueue.go @@ -13,31 +13,31 @@ import ( "fmt" "strings" - "github.com/emirpasic/gods/lists/arraylist" - "github.com/emirpasic/gods/queues" + "github.com/emirpasic/gods/v2/lists/arraylist" + "github.com/emirpasic/gods/v2/queues" ) // Assert Queue implementation -var _ queues.Queue = (*Queue)(nil) +var _ queues.Queue[int] = (*Queue[int])(nil) // Queue holds elements in an array-list -type Queue struct { - list *arraylist.List +type Queue[T comparable] struct { + list *arraylist.List[T] } // New instantiates a new empty queue -func New() *Queue { - return &Queue{list: arraylist.New()} +func New[T comparable]() *Queue[T] { + return &Queue[T]{list: arraylist.New[T]()} } // Enqueue adds a value to the end of the queue -func (queue *Queue) Enqueue(value interface{}) { +func (queue *Queue[T]) Enqueue(value T) { queue.list.Add(value) } // Dequeue removes first element of the queue and returns it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to dequeue. -func (queue *Queue) Dequeue() (value interface{}, ok bool) { +func (queue *Queue[T]) Dequeue() (value T, ok bool) { value, ok = queue.list.Get(0) if ok { queue.list.Remove(0) @@ -47,32 +47,32 @@ func (queue *Queue) Dequeue() (value interface{}, ok bool) { // Peek returns first element of the queue without removing it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to peek. -func (queue *Queue) Peek() (value interface{}, ok bool) { +func (queue *Queue[T]) Peek() (value T, ok bool) { return queue.list.Get(0) } // Empty returns true if queue does not contain any elements. -func (queue *Queue) Empty() bool { +func (queue *Queue[T]) Empty() bool { return queue.list.Empty() } // Size returns number of elements within the queue. -func (queue *Queue) Size() int { +func (queue *Queue[T]) Size() int { return queue.list.Size() } // Clear removes all elements from the queue. -func (queue *Queue) Clear() { +func (queue *Queue[T]) Clear() { queue.list.Clear() } // Values returns all elements in the queue (FIFO order). -func (queue *Queue) Values() []interface{} { +func (queue *Queue[T]) Values() []T { return queue.list.Values() } // String returns a string representation of container -func (queue *Queue) String() string { +func (queue *Queue[T]) String() string { str := "ArrayQueue\n" values := []string{} for _, value := range queue.list.Values() { @@ -83,6 +83,6 @@ func (queue *Queue) String() string { } // Check that the index is within bounds of the list -func (queue *Queue) withinRange(index int) bool { +func (queue *Queue[T]) withinRange(index int) bool { return index >= 0 && index < queue.list.Size() } diff --git a/queues/arrayqueue/arrayqueue_test.go b/queues/arrayqueue/arrayqueue_test.go index b704dbf5..8c295da5 100644 --- a/queues/arrayqueue/arrayqueue_test.go +++ b/queues/arrayqueue/arrayqueue_test.go @@ -6,13 +6,14 @@ package arrayqueue import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestQueueEnqueue(t *testing.T) { - queue := New() + queue := New[int]() if actualValue := queue.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -20,7 +21,7 @@ func TestQueueEnqueue(t *testing.T) { queue.Enqueue(2) queue.Enqueue(3) - if actualValue := queue.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { + if actualValue := queue.Values(); actualValue[0] != 1 || actualValue[1] != 2 || actualValue[2] != 3 { t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") } if actualValue := queue.Empty(); actualValue != false { @@ -35,8 +36,8 @@ func TestQueueEnqueue(t *testing.T) { } func TestQueuePeek(t *testing.T) { - queue := New() - if actualValue, ok := queue.Peek(); actualValue != nil || ok { + queue := New[int]() + if actualValue, ok := queue.Peek(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } queue.Enqueue(1) @@ -48,7 +49,7 @@ func TestQueuePeek(t *testing.T) { } func TestQueueDequeue(t *testing.T) { - queue := New() + queue := New[int]() queue.Enqueue(1) queue.Enqueue(2) queue.Enqueue(3) @@ -62,7 +63,7 @@ func TestQueueDequeue(t *testing.T) { if actualValue, ok := queue.Dequeue(); actualValue != 3 || !ok { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, ok := queue.Dequeue(); actualValue != nil || ok { + if actualValue, ok := queue.Dequeue(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := queue.Empty(); actualValue != true { @@ -74,7 +75,7 @@ func TestQueueDequeue(t *testing.T) { } func TestQueueIteratorOnEmpty(t *testing.T) { - queue := New() + queue := New[int]() it := queue.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty queue") @@ -82,7 +83,7 @@ func TestQueueIteratorOnEmpty(t *testing.T) { } func TestQueueIteratorNext(t *testing.T) { - queue := New() + queue := New[string]() queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") @@ -125,7 +126,7 @@ func TestQueueIteratorNext(t *testing.T) { } func TestQueueIteratorPrev(t *testing.T) { - queue := New() + queue := New[string]() queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") @@ -164,7 +165,7 @@ func TestQueueIteratorPrev(t *testing.T) { } func TestQueueIteratorBegin(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() it.Begin() queue.Enqueue("a") @@ -180,7 +181,7 @@ func TestQueueIteratorBegin(t *testing.T) { } func TestQueueIteratorEnd(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() if index := it.Index(); index != -1 { @@ -207,7 +208,7 @@ func TestQueueIteratorEnd(t *testing.T) { } func TestQueueIteratorFirst(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -224,7 +225,7 @@ func TestQueueIteratorFirst(t *testing.T) { } func TestQueueIteratorLast(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -242,13 +243,13 @@ func TestQueueIteratorLast(t *testing.T) { func TestQueueIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - queue := New() + queue := New[string]() it := queue.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") @@ -257,7 +258,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (not found) { - queue := New() + queue := New[string]() queue.Enqueue("xx") queue.Enqueue("yy") it := queue.Iterator() @@ -268,7 +269,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (found) { - queue := New() + queue := New[string]() queue.Enqueue("aa") queue.Enqueue("bb") queue.Enqueue("cc") @@ -277,13 +278,13 @@ func TestQueueIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -294,13 +295,13 @@ func TestQueueIteratorNextTo(t *testing.T) { func TestQueueIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - queue := New() + queue := New[string]() it := queue.Iterator() it.End() for it.PrevTo(seek) { @@ -310,7 +311,7 @@ func TestQueueIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - queue := New() + queue := New[string]() queue.Enqueue("xx") queue.Enqueue("yy") it := queue.Iterator() @@ -322,7 +323,7 @@ func TestQueueIteratorPrevTo(t *testing.T) { // PrevTo (found) { - queue := New() + queue := New[string]() queue.Enqueue("aa") queue.Enqueue("bb") queue.Enqueue("cc") @@ -331,13 +332,13 @@ func TestQueueIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty queue") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -347,16 +348,14 @@ func TestQueueIteratorPrevTo(t *testing.T) { } func TestQueueSerialization(t *testing.T) { - queue := New() + queue := New[string]() queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", queue.Values()...), "abc"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, queue.Values(), []string{"a", "b", "c"}) if actualValue, expectedValue := queue.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -378,21 +377,22 @@ func TestQueueSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &queue) + err = json.Unmarshal([]byte(`["a","b","c"]`), &queue) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestQueueString(t *testing.T) { - c := New() + c := New[int]() c.Enqueue(1) if !strings.HasPrefix(c.String(), "ArrayQueue") { t.Errorf("String should start with container name") } } -func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { +func benchmarkEnqueue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Enqueue(n) @@ -400,7 +400,7 @@ func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { } } -func benchmarkDequeue(b *testing.B, queue *Queue, size int) { +func benchmarkDequeue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Dequeue() @@ -411,7 +411,7 @@ func benchmarkDequeue(b *testing.B, queue *Queue, size int) { func BenchmarkArrayQueueDequeue100(b *testing.B) { b.StopTimer() size := 100 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -422,7 +422,7 @@ func BenchmarkArrayQueueDequeue100(b *testing.B) { func BenchmarkArrayQueueDequeue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -433,7 +433,7 @@ func BenchmarkArrayQueueDequeue1000(b *testing.B) { func BenchmarkArrayQueueDequeue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -444,7 +444,7 @@ func BenchmarkArrayQueueDequeue10000(b *testing.B) { func BenchmarkArrayQueueDequeue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -455,7 +455,7 @@ func BenchmarkArrayQueueDequeue100000(b *testing.B) { func BenchmarkArrayQueueEnqueue100(b *testing.B) { b.StopTimer() size := 100 - queue := New() + queue := New[int]() b.StartTimer() benchmarkEnqueue(b, queue, size) } @@ -463,7 +463,7 @@ func BenchmarkArrayQueueEnqueue100(b *testing.B) { func BenchmarkArrayQueueEnqueue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -474,7 +474,7 @@ func BenchmarkArrayQueueEnqueue1000(b *testing.B) { func BenchmarkArrayQueueEnqueue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -485,7 +485,7 @@ func BenchmarkArrayQueueEnqueue10000(b *testing.B) { func BenchmarkArrayQueueEnqueue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } diff --git a/queues/arrayqueue/iterator.go b/queues/arrayqueue/iterator.go index 51a30f9a..bc684685 100644 --- a/queues/arrayqueue/iterator.go +++ b/queues/arrayqueue/iterator.go @@ -4,27 +4,27 @@ package arrayqueue -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - queue *Queue +type Iterator[T comparable] struct { + queue *Queue[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (queue *Queue) Iterator() Iterator { - return Iterator{queue: queue, index: -1} +func (queue *Queue[T]) Iterator() *Iterator[T] { + return &Iterator[T]{queue: queue, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.queue.Size() { iterator.index++ } @@ -34,7 +34,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -43,33 +43,33 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { value, _ := iterator.queue.list.Get(iterator.index) return value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.queue.Size() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -77,7 +77,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -86,7 +86,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -100,7 +100,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/queues/arrayqueue/serialization.go b/queues/arrayqueue/serialization.go index a33a82f9..49abb16f 100644 --- a/queues/arrayqueue/serialization.go +++ b/queues/arrayqueue/serialization.go @@ -5,29 +5,29 @@ package arrayqueue import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Queue)(nil) -var _ containers.JSONDeserializer = (*Queue)(nil) +var _ containers.JSONSerializer = (*Queue[int])(nil) +var _ containers.JSONDeserializer = (*Queue[int])(nil) // ToJSON outputs the JSON representation of the queue. -func (queue *Queue) ToJSON() ([]byte, error) { +func (queue *Queue[T]) ToJSON() ([]byte, error) { return queue.list.ToJSON() } // FromJSON populates the queue from the input JSON representation. -func (queue *Queue) FromJSON(data []byte) error { +func (queue *Queue[T]) FromJSON(data []byte) error { return queue.list.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (queue *Queue) UnmarshalJSON(bytes []byte) error { +func (queue *Queue[T]) UnmarshalJSON(bytes []byte) error { return queue.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (queue *Queue) MarshalJSON() ([]byte, error) { +func (queue *Queue[T]) MarshalJSON() ([]byte, error) { return queue.ToJSON() } diff --git a/queues/circularbuffer/circularbuffer.go b/queues/circularbuffer/circularbuffer.go index f74a55b1..ef71fdd4 100644 --- a/queues/circularbuffer/circularbuffer.go +++ b/queues/circularbuffer/circularbuffer.go @@ -15,15 +15,15 @@ import ( "fmt" "strings" - "github.com/emirpasic/gods/queues" + "github.com/emirpasic/gods/v2/queues" ) // Assert Queue implementation -var _ queues.Queue = (*Queue)(nil) +var _ queues.Queue[int] = (*Queue[int])(nil) // Queue holds values in a slice. -type Queue struct { - values []interface{} +type Queue[T comparable] struct { + values []T start int end int full bool @@ -33,17 +33,17 @@ type Queue struct { // New instantiates a new empty queue with the specified size of maximum number of elements that it can hold. // This max size of the buffer cannot be changed. -func New(maxSize int) *Queue { +func New[T comparable](maxSize int) *Queue[T] { if maxSize < 1 { panic("Invalid maxSize, should be at least 1") } - queue := &Queue{maxSize: maxSize} + queue := &Queue[T]{maxSize: maxSize} queue.Clear() return queue } // Enqueue adds a value to the end of the queue -func (queue *Queue) Enqueue(value interface{}) { +func (queue *Queue[T]) Enqueue(value T) { if queue.Full() { queue.Dequeue() } @@ -59,24 +59,19 @@ func (queue *Queue) Enqueue(value interface{}) { queue.size = queue.calculateSize() } -// Dequeue removes first element of the queue and returns it, or nil if queue is empty. +// Dequeue removes first element of the queue and returns it, or the 0-value if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to dequeue. -func (queue *Queue) Dequeue() (value interface{}, ok bool) { +func (queue *Queue[T]) Dequeue() (value T, ok bool) { if queue.Empty() { - return nil, false + return value, false } value, ok = queue.values[queue.start], true - - if value != nil { - queue.values[queue.start] = nil - queue.start = queue.start + 1 - if queue.start >= queue.maxSize { - queue.start = 0 - } - queue.full = false + queue.start = queue.start + 1 + if queue.start >= queue.maxSize { + queue.start = 0 } - + queue.full = false queue.size = queue.size - 1 return @@ -84,31 +79,31 @@ func (queue *Queue) Dequeue() (value interface{}, ok bool) { // Peek returns first element of the queue without removing it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to peek. -func (queue *Queue) Peek() (value interface{}, ok bool) { +func (queue *Queue[T]) Peek() (value T, ok bool) { if queue.Empty() { - return nil, false + return value, false } return queue.values[queue.start], true } // Empty returns true if queue does not contain any elements. -func (queue *Queue) Empty() bool { +func (queue *Queue[T]) Empty() bool { return queue.Size() == 0 } // Full returns true if the queue is full, i.e. has reached the maximum number of elements that it can hold. -func (queue *Queue) Full() bool { +func (queue *Queue[T]) Full() bool { return queue.Size() == queue.maxSize } // Size returns number of elements within the queue. -func (queue *Queue) Size() int { +func (queue *Queue[T]) Size() int { return queue.size } // Clear removes all elements from the queue. -func (queue *Queue) Clear() { - queue.values = make([]interface{}, queue.maxSize, queue.maxSize) +func (queue *Queue[T]) Clear() { + queue.values = make([]T, queue.maxSize, queue.maxSize) queue.start = 0 queue.end = 0 queue.full = false @@ -116,8 +111,8 @@ func (queue *Queue) Clear() { } // Values returns all elements in the queue (FIFO order). -func (queue *Queue) Values() []interface{} { - values := make([]interface{}, queue.Size(), queue.Size()) +func (queue *Queue[T]) Values() []T { + values := make([]T, queue.Size(), queue.Size()) for i := 0; i < queue.Size(); i++ { values[i] = queue.values[(queue.start+i)%queue.maxSize] } @@ -125,7 +120,7 @@ func (queue *Queue) Values() []interface{} { } // String returns a string representation of container -func (queue *Queue) String() string { +func (queue *Queue[T]) String() string { str := "CircularBuffer\n" var values []string for _, value := range queue.Values() { @@ -136,11 +131,11 @@ func (queue *Queue) String() string { } // Check that the index is within bounds of the list -func (queue *Queue) withinRange(index int) bool { +func (queue *Queue[T]) withinRange(index int) bool { return index >= 0 && index < queue.size } -func (queue *Queue) calculateSize() int { +func (queue *Queue[T]) calculateSize() int { if queue.end < queue.start { return queue.maxSize - queue.start + queue.end } else if queue.end == queue.start { diff --git a/queues/circularbuffer/circularbuffer_test.go b/queues/circularbuffer/circularbuffer_test.go index 676ea7eb..d48bb3ae 100644 --- a/queues/circularbuffer/circularbuffer_test.go +++ b/queues/circularbuffer/circularbuffer_test.go @@ -6,13 +6,14 @@ package circularbuffer import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestQueueEnqueue(t *testing.T) { - queue := New(3) + queue := New[int](3) if actualValue := queue.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -20,7 +21,7 @@ func TestQueueEnqueue(t *testing.T) { queue.Enqueue(2) queue.Enqueue(3) - if actualValue := queue.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { + if actualValue := queue.Values(); actualValue[0] != 1 || actualValue[1] != 2 || actualValue[2] != 3 { t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") } if actualValue := queue.Empty(); actualValue != false { @@ -35,8 +36,8 @@ func TestQueueEnqueue(t *testing.T) { } func TestQueuePeek(t *testing.T) { - queue := New(3) - if actualValue, ok := queue.Peek(); actualValue != nil || ok { + queue := New[int](3) + if actualValue, ok := queue.Peek(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } queue.Enqueue(1) @@ -54,7 +55,7 @@ func TestQueueDequeue(t *testing.T) { } } - queue := New(3) + queue := New[int](3) assert(queue.Empty(), true) assert(queue.Empty(), true) assert(queue.Full(), false) @@ -89,7 +90,7 @@ func TestQueueDequeue(t *testing.T) { assert(queue.Empty(), true) assert(queue.Full(), false) - if actualValue, ok := queue.Dequeue(); actualValue != nil || ok { + if actualValue, ok := queue.Dequeue(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } assert(queue.Size(), 0) @@ -106,7 +107,7 @@ func TestQueueDequeueFull(t *testing.T) { } } - queue := New(2) + queue := New[int](2) assert(queue.Empty(), true) assert(queue.Full(), false) assert(queue.Size(), 0) @@ -142,7 +143,7 @@ func TestQueueDequeueFull(t *testing.T) { } assert(queue.Size(), 0) - if actualValue, ok := queue.Dequeue(); actualValue != nil || ok { + if actualValue, ok := queue.Dequeue(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } assert(queue.Empty(), true) @@ -151,7 +152,7 @@ func TestQueueDequeueFull(t *testing.T) { } func TestQueueIteratorOnEmpty(t *testing.T) { - queue := New(3) + queue := New[int](3) it := queue.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty queue") @@ -159,7 +160,7 @@ func TestQueueIteratorOnEmpty(t *testing.T) { } func TestQueueIteratorNext(t *testing.T) { - queue := New(3) + queue := New[string](3) queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") @@ -202,7 +203,7 @@ func TestQueueIteratorNext(t *testing.T) { } func TestQueueIteratorPrev(t *testing.T) { - queue := New(3) + queue := New[string](3) queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") @@ -241,7 +242,7 @@ func TestQueueIteratorPrev(t *testing.T) { } func TestQueueIteratorBegin(t *testing.T) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() it.Begin() queue.Enqueue("a") @@ -257,7 +258,7 @@ func TestQueueIteratorBegin(t *testing.T) { } func TestQueueIteratorEnd(t *testing.T) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() if index := it.Index(); index != -1 { @@ -284,7 +285,7 @@ func TestQueueIteratorEnd(t *testing.T) { } func TestQueueIteratorFirst(t *testing.T) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -301,7 +302,7 @@ func TestQueueIteratorFirst(t *testing.T) { } func TestQueueIteratorLast(t *testing.T) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -319,13 +320,13 @@ func TestQueueIteratorLast(t *testing.T) { func TestQueueIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") @@ -334,7 +335,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (not found) { - queue := New(3) + queue := New[string](3) queue.Enqueue("xx") queue.Enqueue("yy") it := queue.Iterator() @@ -345,7 +346,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (found) { - queue := New(3) + queue := New[string](3) queue.Enqueue("aa") queue.Enqueue("bb") queue.Enqueue("cc") @@ -354,13 +355,13 @@ func TestQueueIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -371,13 +372,13 @@ func TestQueueIteratorNextTo(t *testing.T) { func TestQueueIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - queue := New(3) + queue := New[string](3) it := queue.Iterator() it.End() for it.PrevTo(seek) { @@ -387,7 +388,7 @@ func TestQueueIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - queue := New(3) + queue := New[string](3) queue.Enqueue("xx") queue.Enqueue("yy") it := queue.Iterator() @@ -399,7 +400,7 @@ func TestQueueIteratorPrevTo(t *testing.T) { // PrevTo (found) { - queue := New(3) + queue := New[string](3) queue.Enqueue("aa") queue.Enqueue("bb") queue.Enqueue("cc") @@ -408,13 +409,13 @@ func TestQueueIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty queue") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -430,7 +431,7 @@ func TestQueueIterator(t *testing.T) { } } - queue := New(2) + queue := New[string](2) queue.Enqueue("a") queue.Enqueue("b") @@ -482,16 +483,14 @@ func TestQueueIterator(t *testing.T) { } func TestQueueSerialization(t *testing.T) { - queue := New(3) + queue := New[string](3) queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", queue.Values()...), "abc"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, queue.Values(), []string{"a", "b", "c"}) if actualValue, expectedValue := queue.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -513,21 +512,22 @@ func TestQueueSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &queue) + err = json.Unmarshal([]byte(`["a","b","c"]`), &queue) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestQueueString(t *testing.T) { - c := New(3) + c := New[int](3) c.Enqueue(1) if !strings.HasPrefix(c.String(), "CircularBuffer") { t.Errorf("String should start with container name") } } -func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { +func benchmarkEnqueue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Enqueue(n) @@ -535,7 +535,7 @@ func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { } } -func benchmarkDequeue(b *testing.B, queue *Queue, size int) { +func benchmarkDequeue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Dequeue() @@ -546,7 +546,7 @@ func benchmarkDequeue(b *testing.B, queue *Queue, size int) { func BenchmarkArrayQueueDequeue100(b *testing.B) { b.StopTimer() size := 100 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -557,7 +557,7 @@ func BenchmarkArrayQueueDequeue100(b *testing.B) { func BenchmarkArrayQueueDequeue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -568,7 +568,7 @@ func BenchmarkArrayQueueDequeue1000(b *testing.B) { func BenchmarkArrayQueueDequeue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -579,7 +579,7 @@ func BenchmarkArrayQueueDequeue10000(b *testing.B) { func BenchmarkArrayQueueDequeue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -590,7 +590,7 @@ func BenchmarkArrayQueueDequeue100000(b *testing.B) { func BenchmarkArrayQueueEnqueue100(b *testing.B) { b.StopTimer() size := 100 - queue := New(3) + queue := New[int](3) b.StartTimer() benchmarkEnqueue(b, queue, size) } @@ -598,7 +598,7 @@ func BenchmarkArrayQueueEnqueue100(b *testing.B) { func BenchmarkArrayQueueEnqueue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -609,7 +609,7 @@ func BenchmarkArrayQueueEnqueue1000(b *testing.B) { func BenchmarkArrayQueueEnqueue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -620,7 +620,7 @@ func BenchmarkArrayQueueEnqueue10000(b *testing.B) { func BenchmarkArrayQueueEnqueue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New(3) + queue := New[int](3) for n := 0; n < size; n++ { queue.Enqueue(n) } diff --git a/queues/circularbuffer/iterator.go b/queues/circularbuffer/iterator.go index dae30ce4..be6af87d 100644 --- a/queues/circularbuffer/iterator.go +++ b/queues/circularbuffer/iterator.go @@ -4,27 +4,27 @@ package circularbuffer -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - queue *Queue +type Iterator[T comparable] struct { + queue *Queue[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (queue *Queue) Iterator() Iterator { - return Iterator{queue: queue, index: -1} +func (queue *Queue[T]) Iterator() *Iterator[T] { + return &Iterator[T]{queue: queue, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.queue.size { iterator.index++ } @@ -34,7 +34,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -43,7 +43,7 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { index := (iterator.index + iterator.queue.start) % iterator.queue.maxSize value := iterator.queue.values[index] return value @@ -51,26 +51,26 @@ func (iterator *Iterator) Value() interface{} { // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.queue.size } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -78,7 +78,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -87,7 +87,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -101,7 +101,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/queues/circularbuffer/serialization.go b/queues/circularbuffer/serialization.go index da2543d0..a020bac0 100644 --- a/queues/circularbuffer/serialization.go +++ b/queues/circularbuffer/serialization.go @@ -6,21 +6,22 @@ package circularbuffer import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Queue)(nil) -var _ containers.JSONDeserializer = (*Queue)(nil) +var _ containers.JSONSerializer = (*Queue[int])(nil) +var _ containers.JSONDeserializer = (*Queue[int])(nil) // ToJSON outputs the JSON representation of queue's elements. -func (queue *Queue) ToJSON() ([]byte, error) { +func (queue *Queue[T]) ToJSON() ([]byte, error) { return json.Marshal(queue.values[:queue.maxSize]) } // FromJSON populates list's elements from the input JSON representation. -func (queue *Queue) FromJSON(data []byte) error { - var values []interface{} +func (queue *Queue[T]) FromJSON(data []byte) error { + var values []T err := json.Unmarshal(data, &values) if err == nil { for _, value := range values { @@ -31,11 +32,11 @@ func (queue *Queue) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (queue *Queue) UnmarshalJSON(bytes []byte) error { +func (queue *Queue[T]) UnmarshalJSON(bytes []byte) error { return queue.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (queue *Queue) MarshalJSON() ([]byte, error) { +func (queue *Queue[T]) MarshalJSON() ([]byte, error) { return queue.ToJSON() } diff --git a/queues/linkedlistqueue/iterator.go b/queues/linkedlistqueue/iterator.go index cf47b191..29bfffda 100644 --- a/queues/linkedlistqueue/iterator.go +++ b/queues/linkedlistqueue/iterator.go @@ -4,27 +4,27 @@ package linkedlistqueue -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.IteratorWithIndex = (*Iterator)(nil) +var _ containers.IteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - queue *Queue +type Iterator[T comparable] struct { + queue *Queue[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (queue *Queue) Iterator() Iterator { - return Iterator{queue: queue, index: -1} +func (queue *Queue[T]) Iterator() *Iterator[T] { + return &Iterator[T]{queue: queue, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.queue.Size() { iterator.index++ } @@ -33,27 +33,27 @@ func (iterator *Iterator) Next() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { value, _ := iterator.queue.list.Get(iterator.index) return value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -62,7 +62,7 @@ func (iterator *Iterator) First() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/queues/linkedlistqueue/linkedlistqueue.go b/queues/linkedlistqueue/linkedlistqueue.go index fdb94636..32ad9f32 100644 --- a/queues/linkedlistqueue/linkedlistqueue.go +++ b/queues/linkedlistqueue/linkedlistqueue.go @@ -13,31 +13,31 @@ import ( "fmt" "strings" - "github.com/emirpasic/gods/lists/singlylinkedlist" - "github.com/emirpasic/gods/queues" + "github.com/emirpasic/gods/v2/lists/singlylinkedlist" + "github.com/emirpasic/gods/v2/queues" ) // Assert Queue implementation -var _ queues.Queue = (*Queue)(nil) +var _ queues.Queue[int] = (*Queue[int])(nil) // Queue holds elements in a singly-linked-list -type Queue struct { - list *singlylinkedlist.List +type Queue[T comparable] struct { + list *singlylinkedlist.List[T] } // New instantiates a new empty queue -func New() *Queue { - return &Queue{list: &singlylinkedlist.List{}} +func New[T comparable]() *Queue[T] { + return &Queue[T]{list: singlylinkedlist.New[T]()} } // Enqueue adds a value to the end of the queue -func (queue *Queue) Enqueue(value interface{}) { +func (queue *Queue[T]) Enqueue(value T) { queue.list.Add(value) } // Dequeue removes first element of the queue and returns it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to dequeue. -func (queue *Queue) Dequeue() (value interface{}, ok bool) { +func (queue *Queue[T]) Dequeue() (value T, ok bool) { value, ok = queue.list.Get(0) if ok { queue.list.Remove(0) @@ -47,32 +47,32 @@ func (queue *Queue) Dequeue() (value interface{}, ok bool) { // Peek returns first element of the queue without removing it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to peek. -func (queue *Queue) Peek() (value interface{}, ok bool) { +func (queue *Queue[T]) Peek() (value T, ok bool) { return queue.list.Get(0) } // Empty returns true if queue does not contain any elements. -func (queue *Queue) Empty() bool { +func (queue *Queue[T]) Empty() bool { return queue.list.Empty() } // Size returns number of elements within the queue. -func (queue *Queue) Size() int { +func (queue *Queue[T]) Size() int { return queue.list.Size() } // Clear removes all elements from the queue. -func (queue *Queue) Clear() { +func (queue *Queue[T]) Clear() { queue.list.Clear() } // Values returns all elements in the queue (FIFO order). -func (queue *Queue) Values() []interface{} { +func (queue *Queue[T]) Values() []T { return queue.list.Values() } // String returns a string representation of container -func (queue *Queue) String() string { +func (queue *Queue[T]) String() string { str := "LinkedListQueue\n" values := []string{} for _, value := range queue.list.Values() { @@ -83,6 +83,6 @@ func (queue *Queue) String() string { } // Check that the index is within bounds of the list -func (queue *Queue) withinRange(index int) bool { +func (queue *Queue[T]) withinRange(index int) bool { return index >= 0 && index < queue.list.Size() } diff --git a/queues/linkedlistqueue/linkedlistqueue_test.go b/queues/linkedlistqueue/linkedlistqueue_test.go index e8e7c749..60ae2f75 100644 --- a/queues/linkedlistqueue/linkedlistqueue_test.go +++ b/queues/linkedlistqueue/linkedlistqueue_test.go @@ -6,13 +6,14 @@ package linkedlistqueue import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestQueueEnqueue(t *testing.T) { - queue := New() + queue := New[int]() if actualValue := queue.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -20,7 +21,7 @@ func TestQueueEnqueue(t *testing.T) { queue.Enqueue(2) queue.Enqueue(3) - if actualValue := queue.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { + if actualValue := queue.Values(); actualValue[0] != 1 || actualValue[1] != 2 || actualValue[2] != 3 { t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") } if actualValue := queue.Empty(); actualValue != false { @@ -35,8 +36,8 @@ func TestQueueEnqueue(t *testing.T) { } func TestQueuePeek(t *testing.T) { - queue := New() - if actualValue, ok := queue.Peek(); actualValue != nil || ok { + queue := New[int]() + if actualValue, ok := queue.Peek(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } queue.Enqueue(1) @@ -48,7 +49,7 @@ func TestQueuePeek(t *testing.T) { } func TestQueueDequeue(t *testing.T) { - queue := New() + queue := New[int]() queue.Enqueue(1) queue.Enqueue(2) queue.Enqueue(3) @@ -62,7 +63,7 @@ func TestQueueDequeue(t *testing.T) { if actualValue, ok := queue.Dequeue(); actualValue != 3 || !ok { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, ok := queue.Dequeue(); actualValue != nil || ok { + if actualValue, ok := queue.Dequeue(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := queue.Empty(); actualValue != true { @@ -74,7 +75,7 @@ func TestQueueDequeue(t *testing.T) { } func TestQueueIteratorOnEmpty(t *testing.T) { - queue := New() + queue := New[int]() it := queue.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty queue") @@ -82,7 +83,7 @@ func TestQueueIteratorOnEmpty(t *testing.T) { } func TestQueueIteratorNext(t *testing.T) { - queue := New() + queue := New[string]() queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") @@ -125,7 +126,7 @@ func TestQueueIteratorNext(t *testing.T) { } func TestQueueIteratorBegin(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() it.Begin() queue.Enqueue("a") @@ -141,7 +142,7 @@ func TestQueueIteratorBegin(t *testing.T) { } func TestQueueIteratorFirst(t *testing.T) { - queue := New() + queue := New[string]() it := queue.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -159,13 +160,13 @@ func TestQueueIteratorFirst(t *testing.T) { func TestQueueIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - queue := New() + queue := New[string]() it := queue.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") @@ -174,7 +175,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (not found) { - queue := New() + queue := New[string]() queue.Enqueue("xx") queue.Enqueue("yy") it := queue.Iterator() @@ -185,7 +186,7 @@ func TestQueueIteratorNextTo(t *testing.T) { // NextTo (found) { - queue := New() + queue := New[string]() queue.Enqueue("aa") queue.Enqueue("bb") queue.Enqueue("cc") @@ -194,13 +195,13 @@ func TestQueueIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty queue") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -210,16 +211,14 @@ func TestQueueIteratorNextTo(t *testing.T) { } func TestQueueSerialization(t *testing.T) { - queue := New() + queue := New[string]() queue.Enqueue("a") queue.Enqueue("b") queue.Enqueue("c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", queue.Values()...), "abc"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, queue.Values(), []string{"a", "b", "c"}) if actualValue, expectedValue := queue.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -241,21 +240,22 @@ func TestQueueSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &queue) + err = json.Unmarshal([]byte(`["a","b","c"]`), &queue) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestQueueString(t *testing.T) { - c := New() + c := New[int]() c.Enqueue(1) if !strings.HasPrefix(c.String(), "LinkedListQueue") { t.Errorf("String should start with container name") } } -func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { +func benchmarkEnqueue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Enqueue(n) @@ -263,7 +263,7 @@ func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { } } -func benchmarkDequeue(b *testing.B, queue *Queue, size int) { +func benchmarkDequeue(b *testing.B, queue *Queue[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Dequeue() @@ -274,7 +274,7 @@ func benchmarkDequeue(b *testing.B, queue *Queue, size int) { func BenchmarkArrayQueueDequeue100(b *testing.B) { b.StopTimer() size := 100 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -285,7 +285,7 @@ func BenchmarkArrayQueueDequeue100(b *testing.B) { func BenchmarkArrayQueueDequeue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -296,7 +296,7 @@ func BenchmarkArrayQueueDequeue1000(b *testing.B) { func BenchmarkArrayQueueDequeue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -307,7 +307,7 @@ func BenchmarkArrayQueueDequeue10000(b *testing.B) { func BenchmarkArrayQueueDequeue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -318,7 +318,7 @@ func BenchmarkArrayQueueDequeue100000(b *testing.B) { func BenchmarkArrayQueueEnqueue100(b *testing.B) { b.StopTimer() size := 100 - queue := New() + queue := New[int]() b.StartTimer() benchmarkEnqueue(b, queue, size) } @@ -326,7 +326,7 @@ func BenchmarkArrayQueueEnqueue100(b *testing.B) { func BenchmarkArrayQueueEnqueue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -337,7 +337,7 @@ func BenchmarkArrayQueueEnqueue1000(b *testing.B) { func BenchmarkArrayQueueEnqueue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } @@ -348,7 +348,7 @@ func BenchmarkArrayQueueEnqueue10000(b *testing.B) { func BenchmarkArrayQueueEnqueue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := New() + queue := New[int]() for n := 0; n < size; n++ { queue.Enqueue(n) } diff --git a/queues/linkedlistqueue/serialization.go b/queues/linkedlistqueue/serialization.go index 2b34c8e4..8e9c157d 100644 --- a/queues/linkedlistqueue/serialization.go +++ b/queues/linkedlistqueue/serialization.go @@ -5,29 +5,29 @@ package linkedlistqueue import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Queue)(nil) -var _ containers.JSONDeserializer = (*Queue)(nil) +var _ containers.JSONSerializer = (*Queue[int])(nil) +var _ containers.JSONDeserializer = (*Queue[int])(nil) // ToJSON outputs the JSON representation of the queue. -func (queue *Queue) ToJSON() ([]byte, error) { +func (queue *Queue[T]) ToJSON() ([]byte, error) { return queue.list.ToJSON() } // FromJSON populates the queue from the input JSON representation. -func (queue *Queue) FromJSON(data []byte) error { +func (queue *Queue[T]) FromJSON(data []byte) error { return queue.list.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (queue *Queue) UnmarshalJSON(bytes []byte) error { +func (queue *Queue[T]) UnmarshalJSON(bytes []byte) error { return queue.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (queue *Queue) MarshalJSON() ([]byte, error) { +func (queue *Queue[T]) MarshalJSON() ([]byte, error) { return queue.ToJSON() } diff --git a/queues/priorityqueue/iterator.go b/queues/priorityqueue/iterator.go index ea6181a2..1a55e07f 100644 --- a/queues/priorityqueue/iterator.go +++ b/queues/priorityqueue/iterator.go @@ -5,73 +5,73 @@ package priorityqueue import ( - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/trees/binaryheap" + "github.com/emirpasic/gods/v2/containers" + "github.com/emirpasic/gods/v2/trees/binaryheap" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - iterator binaryheap.Iterator +type Iterator[T comparable] struct { + iterator *binaryheap.Iterator[T] } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (queue *Queue) Iterator() Iterator { - return Iterator{iterator: queue.heap.Iterator()} +func (queue *Queue[T]) Iterator() *Iterator[T] { + return &Iterator[T]{iterator: queue.heap.Iterator()} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { return iterator.iterator.Next() } // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { return iterator.iterator.Prev() } // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.iterator.Value() } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.iterator.Index() } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.iterator.End() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { return iterator.iterator.First() } // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { return iterator.iterator.Last() } @@ -79,7 +79,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { return iterator.iterator.NextTo(f) } @@ -87,6 +87,6 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { return iterator.iterator.PrevTo(f) } diff --git a/queues/priorityqueue/priorityqueue.go b/queues/priorityqueue/priorityqueue.go index 3a7e6f22..4427f75e 100644 --- a/queues/priorityqueue/priorityqueue.go +++ b/queues/priorityqueue/priorityqueue.go @@ -16,66 +16,72 @@ package priorityqueue import ( + "cmp" "fmt" - "github.com/emirpasic/gods/queues" - "github.com/emirpasic/gods/trees/binaryheap" - "github.com/emirpasic/gods/utils" "strings" + + "github.com/emirpasic/gods/v2/queues" + "github.com/emirpasic/gods/v2/trees/binaryheap" + "github.com/emirpasic/gods/v2/utils" ) // Assert Queue implementation -var _ queues.Queue = (*Queue)(nil) +var _ queues.Queue[int] = (*Queue[int])(nil) // Queue holds elements in an array-list -type Queue struct { - heap *binaryheap.Heap - Comparator utils.Comparator +type Queue[T comparable] struct { + heap *binaryheap.Heap[T] + Comparator utils.Comparator[T] +} + +func New[T cmp.Ordered]() *Queue[T] { + return NewWith[T](cmp.Compare[T]) } // NewWith instantiates a new empty queue with the custom comparator. -func NewWith(comparator utils.Comparator) *Queue { - return &Queue{heap: binaryheap.NewWith(comparator), Comparator: comparator} +func NewWith[T comparable](comparator utils.Comparator[T]) *Queue[T] { + return &Queue[T]{heap: binaryheap.NewWith(comparator), Comparator: comparator} } // Enqueue adds a value to the end of the queue -func (queue *Queue) Enqueue(value interface{}) { +func (queue *Queue[T]) Enqueue(value T) { queue.heap.Push(value) } // Dequeue removes first element of the queue and returns it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to dequeue. -func (queue *Queue) Dequeue() (value interface{}, ok bool) { +func (queue *Queue[T]) Dequeue() (value T, ok bool) { return queue.heap.Pop() } // Peek returns top element on the queue without removing it, or nil if queue is empty. // Second return parameter is true, unless the queue was empty and there was nothing to peek. -func (queue *Queue) Peek() (value interface{}, ok bool) { +func (queue *Queue[T]) Peek() (value T, ok bool) { return queue.heap.Peek() } // Empty returns true if queue does not contain any elements. -func (queue *Queue) Empty() bool { +func (queue *Queue[T]) Empty() bool { return queue.heap.Empty() } // Size returns number of elements within the queue. -func (queue *Queue) Size() int { +func (queue *Queue[T]) Size() int { return queue.heap.Size() } // Clear removes all elements from the queue. -func (queue *Queue) Clear() { +func (queue *Queue[T]) Clear() { queue.heap.Clear() } // Values returns all elements in the queue. -func (queue *Queue) Values() []interface{} { +func (queue *Queue[T]) Values() []T { return queue.heap.Values() } // String returns a string representation of container -func (queue *Queue) String() string { +func (queue *Queue[T]) String() string { str := "PriorityQueue\n" values := make([]string, queue.heap.Size(), queue.heap.Size()) for index, value := range queue.heap.Values() { diff --git a/queues/priorityqueue/priorityqueue_test.go b/queues/priorityqueue/priorityqueue_test.go index 6c0db896..1a21e571 100644 --- a/queues/priorityqueue/priorityqueue_test.go +++ b/queues/priorityqueue/priorityqueue_test.go @@ -5,9 +5,9 @@ package priorityqueue import ( + "cmp" "encoding/json" "fmt" - "github.com/emirpasic/gods/utils" "math/rand" "strings" "testing" @@ -23,15 +23,12 @@ func (element Element) String() string { } // Comparator function (sort by priority value in descending order) -func byPriority(a, b interface{}) int { - return -utils.IntComparator( // Note "-" for descending order - a.(Element).priority, - b.(Element).priority, - ) +func byPriority(a, b Element) int { + return -cmp.Compare(a.priority, b.priority) // Note "-" for descending order } func TestBinaryQueueEnqueue(t *testing.T) { - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) if actualValue := queue.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) @@ -53,15 +50,15 @@ func TestBinaryQueueEnqueue(t *testing.T) { value := it.Value() switch index { case 0: - if actualValue, expectedValue := value.(Element).name, "c"; actualValue != expectedValue { + if actualValue, expectedValue := value.name, "c"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 1: - if actualValue, expectedValue := value.(Element).name, "b"; actualValue != expectedValue { + if actualValue, expectedValue := value.name, "b"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 2: - if actualValue, expectedValue := value.(Element).name, "a"; actualValue != expectedValue { + if actualValue, expectedValue := value.name, "a"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } default: @@ -72,13 +69,13 @@ func TestBinaryQueueEnqueue(t *testing.T) { } } - if actualValue := queue.Values(); actualValue[0].(Element).name != "c" || actualValue[1].(Element).name != "b" || actualValue[2].(Element).name != "a" { + if actualValue := queue.Values(); actualValue[0].name != "c" || actualValue[1].name != "b" || actualValue[2].name != "a" { t.Errorf("Got %v expected %v", actualValue, `[{3 c} {2 b} {1 a}]`) } } func TestBinaryQueueEnqueueBulk(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() queue.Enqueue(15) queue.Enqueue(20) @@ -109,7 +106,7 @@ func TestBinaryQueueEnqueueBulk(t *testing.T) { } func TestBinaryQueueDequeue(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() if actualValue := queue.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) @@ -126,7 +123,7 @@ func TestBinaryQueueDequeue(t *testing.T) { if actualValue, ok := queue.Dequeue(); actualValue != 3 || !ok { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, ok := queue.Dequeue(); actualValue != nil || ok { + if actualValue, ok := queue.Dequeue(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := queue.Empty(); actualValue != true { @@ -138,7 +135,7 @@ func TestBinaryQueueDequeue(t *testing.T) { } func TestBinaryQueueRandom(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() rand.Seed(3) for i := 0; i < 10000; i++ { @@ -149,7 +146,7 @@ func TestBinaryQueueRandom(t *testing.T) { prev, _ := queue.Dequeue() for !queue.Empty() { curr, _ := queue.Dequeue() - if prev.(int) > curr.(int) { + if prev > curr { t.Errorf("Queue property invalidated. prev: %v current: %v", prev, curr) } prev = curr @@ -157,7 +154,7 @@ func TestBinaryQueueRandom(t *testing.T) { } func TestBinaryQueueIteratorOnEmpty(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() it := queue.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty queue") @@ -165,7 +162,7 @@ func TestBinaryQueueIteratorOnEmpty(t *testing.T) { } func TestBinaryQueueIteratorNext(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() queue.Enqueue(3) queue.Enqueue(2) queue.Enqueue(1) @@ -202,7 +199,7 @@ func TestBinaryQueueIteratorNext(t *testing.T) { } func TestBinaryQueueIteratorPrev(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() queue.Enqueue(3) queue.Enqueue(2) queue.Enqueue(1) @@ -241,7 +238,7 @@ func TestBinaryQueueIteratorPrev(t *testing.T) { } func TestBinaryQueueIteratorBegin(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() it := queue.Iterator() it.Begin() queue.Enqueue(2) @@ -257,7 +254,7 @@ func TestBinaryQueueIteratorBegin(t *testing.T) { } func TestBinaryQueueIteratorEnd(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() it := queue.Iterator() if index := it.Index(); index != -1 { @@ -284,7 +281,7 @@ func TestBinaryQueueIteratorEnd(t *testing.T) { } func TestBinaryQueueIteratorFirst(t *testing.T) { - queue := NewWith(utils.IntComparator) + queue := New[int]() it := queue.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -301,7 +298,7 @@ func TestBinaryQueueIteratorFirst(t *testing.T) { } func TestBinaryQueueIteratorLast(t *testing.T) { - tree := NewWith(utils.IntComparator) + tree := New[int]() it := tree.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -319,13 +316,13 @@ func TestBinaryQueueIteratorLast(t *testing.T) { func TestBinaryQueueIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - tree := NewWith(utils.StringComparator) + tree := New[string]() it := tree.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") @@ -334,7 +331,7 @@ func TestBinaryQueueIteratorNextTo(t *testing.T) { // NextTo (not found) { - tree := NewWith(utils.StringComparator) + tree := New[string]() tree.Enqueue("xx") tree.Enqueue("yy") it := tree.Iterator() @@ -345,7 +342,7 @@ func TestBinaryQueueIteratorNextTo(t *testing.T) { // NextTo (found) { - tree := NewWith(utils.StringComparator) + tree := New[string]() tree.Enqueue("aa") tree.Enqueue("bb") tree.Enqueue("cc") @@ -354,13 +351,13 @@ func TestBinaryQueueIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -371,13 +368,13 @@ func TestBinaryQueueIteratorNextTo(t *testing.T) { func TestBinaryQueueIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - tree := NewWith(utils.StringComparator) + tree := New[string]() it := tree.Iterator() it.End() for it.PrevTo(seek) { @@ -387,7 +384,7 @@ func TestBinaryQueueIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - tree := NewWith(utils.StringComparator) + tree := New[string]() tree.Enqueue("xx") tree.Enqueue("yy") it := tree.Iterator() @@ -399,7 +396,7 @@ func TestBinaryQueueIteratorPrevTo(t *testing.T) { // PrevTo (found) { - tree := NewWith(utils.StringComparator) + tree := New[string]() tree.Enqueue("aa") tree.Enqueue("bb") tree.Enqueue("cc") @@ -408,13 +405,13 @@ func TestBinaryQueueIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -424,7 +421,7 @@ func TestBinaryQueueIteratorPrevTo(t *testing.T) { } func TestBinaryQueueSerialization(t *testing.T) { - queue := NewWith(utils.StringComparator) + queue := New[string]() queue.Enqueue("c") queue.Enqueue("b") @@ -432,7 +429,7 @@ func TestBinaryQueueSerialization(t *testing.T) { var err error assert := func() { - if actualValue := queue.Values(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { + if actualValue := queue.Values(); actualValue[0] != "a" || actualValue[1] != "b" || actualValue[2] != "c" { t.Errorf("Got %v expected %v", actualValue, "[1,3,2]") } if actualValue := queue.Size(); actualValue != 3 { @@ -459,21 +456,22 @@ func TestBinaryQueueSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &queue) + err = json.Unmarshal([]byte(`["a","b","c"]`), &queue) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestBTreeString(t *testing.T) { - c := NewWith(byPriority) + c := New[int]() c.Enqueue(1) if !strings.HasPrefix(c.String(), "PriorityQueue") { t.Errorf("String should start with container name") } } -func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { +func benchmarkEnqueue(b *testing.B, queue *Queue[Element], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Enqueue(Element{}) @@ -481,7 +479,7 @@ func benchmarkEnqueue(b *testing.B, queue *Queue, size int) { } } -func benchmarkDequeue(b *testing.B, queue *Queue, size int) { +func benchmarkDequeue(b *testing.B, queue *Queue[Element], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { queue.Dequeue() @@ -492,7 +490,7 @@ func benchmarkDequeue(b *testing.B, queue *Queue, size int) { func BenchmarkBinaryQueueDequeue100(b *testing.B) { b.StopTimer() size := 100 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -503,7 +501,7 @@ func BenchmarkBinaryQueueDequeue100(b *testing.B) { func BenchmarkBinaryQueueDequeue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -514,7 +512,7 @@ func BenchmarkBinaryQueueDequeue1000(b *testing.B) { func BenchmarkBinaryQueueDequeue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -525,7 +523,7 @@ func BenchmarkBinaryQueueDequeue10000(b *testing.B) { func BenchmarkBinaryQueueDequeue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -544,7 +542,7 @@ func BenchmarkBinaryQueueEnqueue100(b *testing.B) { func BenchmarkBinaryQueueEnqueue1000(b *testing.B) { b.StopTimer() size := 1000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -555,7 +553,7 @@ func BenchmarkBinaryQueueEnqueue1000(b *testing.B) { func BenchmarkBinaryQueueEnqueue10000(b *testing.B) { b.StopTimer() size := 10000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } @@ -566,7 +564,7 @@ func BenchmarkBinaryQueueEnqueue10000(b *testing.B) { func BenchmarkBinaryQueueEnqueue100000(b *testing.B) { b.StopTimer() size := 100000 - queue := NewWith(byPriority) + queue := NewWith[Element](byPriority) for n := 0; n < size; n++ { queue.Enqueue(Element{}) } diff --git a/queues/priorityqueue/serialization.go b/queues/priorityqueue/serialization.go index 6072a168..f22548df 100644 --- a/queues/priorityqueue/serialization.go +++ b/queues/priorityqueue/serialization.go @@ -5,29 +5,29 @@ package priorityqueue import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Queue)(nil) -var _ containers.JSONDeserializer = (*Queue)(nil) +var _ containers.JSONSerializer = (*Queue[int])(nil) +var _ containers.JSONDeserializer = (*Queue[int])(nil) // ToJSON outputs the JSON representation of the queue. -func (queue *Queue) ToJSON() ([]byte, error) { +func (queue *Queue[T]) ToJSON() ([]byte, error) { return queue.heap.ToJSON() } // FromJSON populates the queue from the input JSON representation. -func (queue *Queue) FromJSON(data []byte) error { +func (queue *Queue[T]) FromJSON(data []byte) error { return queue.heap.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (queue *Queue) UnmarshalJSON(bytes []byte) error { +func (queue *Queue[T]) UnmarshalJSON(bytes []byte) error { return queue.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (queue *Queue) MarshalJSON() ([]byte, error) { +func (queue *Queue[T]) MarshalJSON() ([]byte, error) { return queue.ToJSON() } diff --git a/queues/queues.go b/queues/queues.go index 80239d48..6f2deb43 100644 --- a/queues/queues.go +++ b/queues/queues.go @@ -10,15 +10,15 @@ // Reference: https://en.wikipedia.org/wiki/Queue_(abstract_data_type) package queues -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Queue interface that all queues implement -type Queue interface { - Enqueue(value interface{}) - Dequeue() (value interface{}, ok bool) - Peek() (value interface{}, ok bool) +type Queue[T comparable] interface { + Enqueue(value T) + Dequeue() (value T, ok bool) + Peek() (value T, ok bool) - containers.Container + containers.Container[T] // Empty() bool // Size() int // Clear() diff --git a/sets/hashset/hashset.go b/sets/hashset/hashset.go index 94399288..32523abf 100644 --- a/sets/hashset/hashset.go +++ b/sets/hashset/hashset.go @@ -11,23 +11,24 @@ package hashset import ( "fmt" - "github.com/emirpasic/gods/sets" "strings" + + "github.com/emirpasic/gods/v2/sets" ) // Assert Set implementation -var _ sets.Set = (*Set)(nil) +var _ sets.Set[int] = (*Set[int])(nil) // Set holds elements in go's native map -type Set struct { - items map[interface{}]struct{} +type Set[T comparable] struct { + items map[T]struct{} } var itemExists = struct{}{} // New instantiates a new empty set and adds the passed values, if any, to the set -func New(values ...interface{}) *Set { - set := &Set{items: make(map[interface{}]struct{})} +func New[T comparable](values ...T) *Set[T] { + set := &Set[T]{items: make(map[T]struct{})} if len(values) > 0 { set.Add(values...) } @@ -35,14 +36,14 @@ func New(values ...interface{}) *Set { } // Add adds the items (one or more) to the set. -func (set *Set) Add(items ...interface{}) { +func (set *Set[T]) Add(items ...T) { for _, item := range items { set.items[item] = itemExists } } // Remove removes the items (one or more) from the set. -func (set *Set) Remove(items ...interface{}) { +func (set *Set[T]) Remove(items ...T) { for _, item := range items { delete(set.items, item) } @@ -51,7 +52,7 @@ func (set *Set) Remove(items ...interface{}) { // Contains check if items (one or more) are present in the set. // All items have to be present in the set for the method to return true. // Returns true if no arguments are passed at all, i.e. set is always superset of empty set. -func (set *Set) Contains(items ...interface{}) bool { +func (set *Set[T]) Contains(items ...T) bool { for _, item := range items { if _, contains := set.items[item]; !contains { return false @@ -61,23 +62,23 @@ func (set *Set) Contains(items ...interface{}) bool { } // Empty returns true if set does not contain any elements. -func (set *Set) Empty() bool { +func (set *Set[T]) Empty() bool { return set.Size() == 0 } // Size returns number of elements within the set. -func (set *Set) Size() int { +func (set *Set[T]) Size() int { return len(set.items) } // Clear clears all values in the set. -func (set *Set) Clear() { - set.items = make(map[interface{}]struct{}) +func (set *Set[T]) Clear() { + set.items = make(map[T]struct{}) } // Values returns all items in the set. -func (set *Set) Values() []interface{} { - values := make([]interface{}, set.Size()) +func (set *Set[T]) Values() []T { + values := make([]T, set.Size()) count := 0 for item := range set.items { values[count] = item @@ -87,7 +88,7 @@ func (set *Set) Values() []interface{} { } // String returns a string representation of container -func (set *Set) String() string { +func (set *Set[T]) String() string { str := "HashSet\n" items := []string{} for k := range set.items { @@ -100,8 +101,8 @@ func (set *Set) String() string { // Intersection returns the intersection between two sets. // The new set consists of all elements that are both in "set" and "another". // Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) -func (set *Set) Intersection(another *Set) *Set { - result := New() +func (set *Set[T]) Intersection(another *Set[T]) *Set[T] { + result := New[T]() // Iterate over smaller set (optimization) if set.Size() <= another.Size() { @@ -124,8 +125,8 @@ func (set *Set) Intersection(another *Set) *Set { // Union returns the union of two sets. // The new set consists of all elements that are in "set" or "another" (possibly both). // Ref: https://en.wikipedia.org/wiki/Union_(set_theory) -func (set *Set) Union(another *Set) *Set { - result := New() +func (set *Set[T]) Union(another *Set[T]) *Set[T] { + result := New[T]() for item := range set.items { result.Add(item) @@ -140,8 +141,8 @@ func (set *Set) Union(another *Set) *Set { // Difference returns the difference between two sets. // The new set consists of all elements that are in "set" but not in "another". // Ref: https://proofwiki.org/wiki/Definition:Set_Difference -func (set *Set) Difference(another *Set) *Set { - result := New() +func (set *Set[T]) Difference(another *Set[T]) *Set[T] { + result := New[T]() for item := range set.items { if _, contains := another.items[item]; !contains { diff --git a/sets/hashset/hashset_test.go b/sets/hashset/hashset_test.go index fe516b79..21d84297 100644 --- a/sets/hashset/hashset_test.go +++ b/sets/hashset/hashset_test.go @@ -28,7 +28,7 @@ func TestSetNew(t *testing.T) { } func TestSetAdd(t *testing.T) { - set := New() + set := New[int]() set.Add() set.Add(1) set.Add(2) @@ -43,7 +43,7 @@ func TestSetAdd(t *testing.T) { } func TestSetContains(t *testing.T) { - set := New() + set := New[int]() set.Add(3, 1, 2) set.Add(2, 3) set.Add() @@ -62,7 +62,7 @@ func TestSetContains(t *testing.T) { } func TestSetRemove(t *testing.T) { - set := New() + set := New[int]() set.Add(3, 1, 2) set.Remove() if actualValue := set.Size(); actualValue != 3 { @@ -82,7 +82,7 @@ func TestSetRemove(t *testing.T) { } func TestSetSerialization(t *testing.T) { - set := New() + set := New[string]() set.Add("a", "b", "c") var err error @@ -111,14 +111,15 @@ func TestSetSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &set) + err = json.Unmarshal([]byte(`["a","b","c"]`), &set) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestSetString(t *testing.T) { - c := New() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "HashSet") { t.Errorf("String should start with container name") @@ -126,8 +127,8 @@ func TestSetString(t *testing.T) { } func TestSetIntersection(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() intersection := set.Intersection(another) if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { @@ -148,8 +149,8 @@ func TestSetIntersection(t *testing.T) { } func TestSetUnion(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() union := set.Union(another) if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { @@ -170,8 +171,8 @@ func TestSetUnion(t *testing.T) { } func TestSetDifference(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() difference := set.Difference(another) if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { @@ -191,7 +192,7 @@ func TestSetDifference(t *testing.T) { } } -func benchmarkContains(b *testing.B, set *Set, size int) { +func benchmarkContains(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Contains(n) @@ -199,7 +200,7 @@ func benchmarkContains(b *testing.B, set *Set, size int) { } } -func benchmarkAdd(b *testing.B, set *Set, size int) { +func benchmarkAdd(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Add(n) @@ -207,7 +208,7 @@ func benchmarkAdd(b *testing.B, set *Set, size int) { } } -func benchmarkRemove(b *testing.B, set *Set, size int) { +func benchmarkRemove(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Remove(n) @@ -218,7 +219,7 @@ func benchmarkRemove(b *testing.B, set *Set, size int) { func BenchmarkHashSetContains100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -229,7 +230,7 @@ func BenchmarkHashSetContains100(b *testing.B) { func BenchmarkHashSetContains1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -240,7 +241,7 @@ func BenchmarkHashSetContains1000(b *testing.B) { func BenchmarkHashSetContains10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -251,7 +252,7 @@ func BenchmarkHashSetContains10000(b *testing.B) { func BenchmarkHashSetContains100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -262,7 +263,7 @@ func BenchmarkHashSetContains100000(b *testing.B) { func BenchmarkHashSetAdd100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() b.StartTimer() benchmarkAdd(b, set, size) } @@ -270,7 +271,7 @@ func BenchmarkHashSetAdd100(b *testing.B) { func BenchmarkHashSetAdd1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -281,7 +282,7 @@ func BenchmarkHashSetAdd1000(b *testing.B) { func BenchmarkHashSetAdd10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -292,7 +293,7 @@ func BenchmarkHashSetAdd10000(b *testing.B) { func BenchmarkHashSetAdd100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -303,7 +304,7 @@ func BenchmarkHashSetAdd100000(b *testing.B) { func BenchmarkHashSetRemove100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -314,7 +315,7 @@ func BenchmarkHashSetRemove100(b *testing.B) { func BenchmarkHashSetRemove1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -325,7 +326,7 @@ func BenchmarkHashSetRemove1000(b *testing.B) { func BenchmarkHashSetRemove10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -336,7 +337,7 @@ func BenchmarkHashSetRemove10000(b *testing.B) { func BenchmarkHashSetRemove100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } diff --git a/sets/hashset/serialization.go b/sets/hashset/serialization.go index 583d129d..4b81ce2f 100644 --- a/sets/hashset/serialization.go +++ b/sets/hashset/serialization.go @@ -6,21 +6,22 @@ package hashset import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Set)(nil) -var _ containers.JSONDeserializer = (*Set)(nil) +var _ containers.JSONSerializer = (*Set[int])(nil) +var _ containers.JSONDeserializer = (*Set[int])(nil) // ToJSON outputs the JSON representation of the set. -func (set *Set) ToJSON() ([]byte, error) { +func (set *Set[T]) ToJSON() ([]byte, error) { return json.Marshal(set.Values()) } // FromJSON populates the set from the input JSON representation. -func (set *Set) FromJSON(data []byte) error { - elements := []interface{}{} +func (set *Set[T]) FromJSON(data []byte) error { + var elements []T err := json.Unmarshal(data, &elements) if err == nil { set.Clear() @@ -30,11 +31,11 @@ func (set *Set) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (set *Set) UnmarshalJSON(bytes []byte) error { +func (set *Set[T]) UnmarshalJSON(bytes []byte) error { return set.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (set *Set) MarshalJSON() ([]byte, error) { +func (set *Set[T]) MarshalJSON() ([]byte, error) { return set.ToJSON() } diff --git a/sets/linkedhashset/enumerable.go b/sets/linkedhashset/enumerable.go index fc855722..2e51edfe 100644 --- a/sets/linkedhashset/enumerable.go +++ b/sets/linkedhashset/enumerable.go @@ -4,13 +4,13 @@ package linkedhashset -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Enumerable implementation -var _ containers.EnumerableWithIndex = (*Set)(nil) +var _ containers.EnumerableWithIndex[int] = (*Set[int])(nil) // Each calls the given function once for each element, passing that element's index and value. -func (set *Set) Each(f func(index int, value interface{})) { +func (set *Set[T]) Each(f func(index int, value T)) { iterator := set.Iterator() for iterator.Next() { f(iterator.Index(), iterator.Value()) @@ -19,8 +19,8 @@ func (set *Set) Each(f func(index int, value interface{})) { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. -func (set *Set) Map(f func(index int, value interface{}) interface{}) *Set { - newSet := New() +func (set *Set[T]) Map(f func(index int, value T) T) *Set[T] { + newSet := New[T]() iterator := set.Iterator() for iterator.Next() { newSet.Add(f(iterator.Index(), iterator.Value())) @@ -29,8 +29,8 @@ func (set *Set) Map(f func(index int, value interface{}) interface{}) *Set { } // Select returns a new container containing all elements for which the given function returns a true value. -func (set *Set) Select(f func(index int, value interface{}) bool) *Set { - newSet := New() +func (set *Set[T]) Select(f func(index int, value T) bool) *Set[T] { + newSet := New[T]() iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -42,7 +42,7 @@ func (set *Set) Select(f func(index int, value interface{}) bool) *Set { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (set *Set) Any(f func(index int, value interface{}) bool) bool { +func (set *Set[T]) Any(f func(index int, value T) bool) bool { iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -54,7 +54,7 @@ func (set *Set) Any(f func(index int, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (set *Set) All(f func(index int, value interface{}) bool) bool { +func (set *Set[T]) All(f func(index int, value T) bool) bool { iterator := set.Iterator() for iterator.Next() { if !f(iterator.Index(), iterator.Value()) { @@ -67,12 +67,13 @@ func (set *Set) All(f func(index int, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. -func (set *Set) Find(f func(index int, value interface{}) bool) (int, interface{}) { +func (set *Set[T]) Find(f func(index int, value T) bool) (int, T) { iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { return iterator.Index(), iterator.Value() } } - return -1, nil + var t T + return -1, t } diff --git a/sets/linkedhashset/iterator.go b/sets/linkedhashset/iterator.go index aa793841..378ec853 100644 --- a/sets/linkedhashset/iterator.go +++ b/sets/linkedhashset/iterator.go @@ -5,73 +5,73 @@ package linkedhashset import ( - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/lists/doublylinkedlist" + "github.com/emirpasic/gods/v2/containers" + "github.com/emirpasic/gods/v2/lists/doublylinkedlist" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator holding the iterator's state -type Iterator struct { - iterator doublylinkedlist.Iterator +type Iterator[T comparable] struct { + iterator doublylinkedlist.Iterator[T] } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (set *Set) Iterator() Iterator { - return Iterator{iterator: set.ordering.Iterator()} +func (set *Set[T]) Iterator() Iterator[T] { + return Iterator[T]{iterator: set.ordering.Iterator()} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { return iterator.iterator.Next() } // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { return iterator.iterator.Prev() } // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.iterator.Value() } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.iterator.Index() } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.iterator.End() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { return iterator.iterator.First() } // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { return iterator.iterator.Last() } @@ -79,7 +79,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -93,7 +93,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/sets/linkedhashset/linkedhashset.go b/sets/linkedhashset/linkedhashset.go index 3bf4e5fe..ea681a9f 100644 --- a/sets/linkedhashset/linkedhashset.go +++ b/sets/linkedhashset/linkedhashset.go @@ -15,27 +15,28 @@ package linkedhashset import ( "fmt" - "github.com/emirpasic/gods/lists/doublylinkedlist" - "github.com/emirpasic/gods/sets" "strings" + + "github.com/emirpasic/gods/v2/lists/doublylinkedlist" + "github.com/emirpasic/gods/v2/sets" ) // Assert Set implementation -var _ sets.Set = (*Set)(nil) +var _ sets.Set[int] = (*Set[int])(nil) // Set holds elements in go's native map -type Set struct { - table map[interface{}]struct{} - ordering *doublylinkedlist.List +type Set[T comparable] struct { + table map[T]struct{} + ordering *doublylinkedlist.List[T] } var itemExists = struct{}{} // New instantiates a new empty set and adds the passed values, if any, to the set -func New(values ...interface{}) *Set { - set := &Set{ - table: make(map[interface{}]struct{}), - ordering: doublylinkedlist.New(), +func New[T comparable](values ...T) *Set[T] { + set := &Set[T]{ + table: make(map[T]struct{}), + ordering: doublylinkedlist.New[T](), } if len(values) > 0 { set.Add(values...) @@ -45,7 +46,7 @@ func New(values ...interface{}) *Set { // Add adds the items (one or more) to the set. // Note that insertion-order is not affected if an element is re-inserted into the set. -func (set *Set) Add(items ...interface{}) { +func (set *Set[T]) Add(items ...T) { for _, item := range items { if _, contains := set.table[item]; !contains { set.table[item] = itemExists @@ -56,7 +57,7 @@ func (set *Set) Add(items ...interface{}) { // Remove removes the items (one or more) from the set. // Slow operation, worst-case O(n^2). -func (set *Set) Remove(items ...interface{}) { +func (set *Set[T]) Remove(items ...T) { for _, item := range items { if _, contains := set.table[item]; contains { delete(set.table, item) @@ -69,7 +70,7 @@ func (set *Set) Remove(items ...interface{}) { // Contains check if items (one or more) are present in the set. // All items have to be present in the set for the method to return true. // Returns true if no arguments are passed at all, i.e. set is always superset of empty set. -func (set *Set) Contains(items ...interface{}) bool { +func (set *Set[T]) Contains(items ...T) bool { for _, item := range items { if _, contains := set.table[item]; !contains { return false @@ -79,24 +80,24 @@ func (set *Set) Contains(items ...interface{}) bool { } // Empty returns true if set does not contain any elements. -func (set *Set) Empty() bool { +func (set *Set[T]) Empty() bool { return set.Size() == 0 } // Size returns number of elements within the set. -func (set *Set) Size() int { +func (set *Set[T]) Size() int { return set.ordering.Size() } // Clear clears all values in the set. -func (set *Set) Clear() { - set.table = make(map[interface{}]struct{}) +func (set *Set[T]) Clear() { + set.table = make(map[T]struct{}) set.ordering.Clear() } // Values returns all items in the set. -func (set *Set) Values() []interface{} { - values := make([]interface{}, set.Size()) +func (set *Set[T]) Values() []T { + values := make([]T, set.Size()) it := set.Iterator() for it.Next() { values[it.Index()] = it.Value() @@ -105,7 +106,7 @@ func (set *Set) Values() []interface{} { } // String returns a string representation of container -func (set *Set) String() string { +func (set *Set[T]) String() string { str := "LinkedHashSet\n" items := []string{} it := set.Iterator() @@ -119,8 +120,8 @@ func (set *Set) String() string { // Intersection returns the intersection between two sets. // The new set consists of all elements that are both in "set" and "another". // Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) -func (set *Set) Intersection(another *Set) *Set { - result := New() +func (set *Set[T]) Intersection(another *Set[T]) *Set[T] { + result := New[T]() // Iterate over smaller set (optimization) if set.Size() <= another.Size() { @@ -143,8 +144,8 @@ func (set *Set) Intersection(another *Set) *Set { // Union returns the union of two sets. // The new set consists of all elements that are in "set" or "another" (possibly both). // Ref: https://en.wikipedia.org/wiki/Union_(set_theory) -func (set *Set) Union(another *Set) *Set { - result := New() +func (set *Set[T]) Union(another *Set[T]) *Set[T] { + result := New[T]() for item := range set.table { result.Add(item) @@ -159,8 +160,8 @@ func (set *Set) Union(another *Set) *Set { // Difference returns the difference between two sets. // The new set consists of all elements that are in "set" but not in "another". // Ref: https://proofwiki.org/wiki/Definition:Set_Difference -func (set *Set) Difference(another *Set) *Set { - result := New() +func (set *Set[T]) Difference(another *Set[T]) *Set[T] { + result := New[T]() for item := range set.table { if _, contains := another.table[item]; !contains { diff --git a/sets/linkedhashset/linkedhashset_test.go b/sets/linkedhashset/linkedhashset_test.go index 3f857ea9..a3a8fac4 100644 --- a/sets/linkedhashset/linkedhashset_test.go +++ b/sets/linkedhashset/linkedhashset_test.go @@ -13,20 +13,23 @@ import ( func TestSetNew(t *testing.T) { set := New(2, 1) + if actualValue := set.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) } - values := set.Values() - if actualValue := values[0]; actualValue != 2 { - t.Errorf("Got %v expected %v", actualValue, 2) + if actualValue := set.Contains(1); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) } - if actualValue := values[1]; actualValue != 1 { - t.Errorf("Got %v expected %v", actualValue, 1) + if actualValue := set.Contains(2); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } + if actualValue := set.Contains(3); actualValue != false { + t.Errorf("Got %v expected %v", actualValue, true) } } func TestSetAdd(t *testing.T) { - set := New() + set := New[int]() set.Add() set.Add(1) set.Add(2) @@ -41,7 +44,7 @@ func TestSetAdd(t *testing.T) { } func TestSetContains(t *testing.T) { - set := New() + set := New[int]() set.Add(3, 1, 2) set.Add(2, 3) set.Add() @@ -60,7 +63,7 @@ func TestSetContains(t *testing.T) { } func TestSetRemove(t *testing.T) { - set := New() + set := New[int]() set.Add(3, 1, 2) set.Remove() if actualValue := set.Size(); actualValue != 3 { @@ -80,9 +83,9 @@ func TestSetRemove(t *testing.T) { } func TestSetEach(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - set.Each(func(index int, value interface{}) { + set.Each(func(index int, value string) { switch index { case 0: if actualValue, expectedValue := value, "c"; actualValue != expectedValue { @@ -103,10 +106,10 @@ func TestSetEach(t *testing.T) { } func TestSetMap(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - mappedSet := set.Map(func(index int, value interface{}) interface{} { - return "mapped: " + value.(string) + mappedSet := set.Map(func(index int, value string) string { + return "mapped: " + value }) if actualValue, expectedValue := mappedSet.Contains("mapped: c", "mapped: b", "mapped: a"), true; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -120,10 +123,10 @@ func TestSetMap(t *testing.T) { } func TestSetSelect(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - selectedSet := set.Select(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + selectedSet := set.Select(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if actualValue, expectedValue := selectedSet.Contains("a", "b"), true; actualValue != expectedValue { fmt.Println("A: ", selectedSet.Contains("b")) @@ -138,16 +141,16 @@ func TestSetSelect(t *testing.T) { } func TestSetAny(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - any := set.Any(func(index int, value interface{}) bool { - return value.(string) == "c" + any := set.Any(func(index int, value string) bool { + return value == "c" }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = set.Any(func(index int, value interface{}) bool { - return value.(string) == "x" + any = set.Any(func(index int, value string) bool { + return value == "x" }) if any != false { t.Errorf("Got %v expected %v", any, false) @@ -155,16 +158,16 @@ func TestSetAny(t *testing.T) { } func TestSetAll(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - all := set.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "c" + all := set.All(func(index int, value string) bool { + return value >= "a" && value <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = set.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + all = set.All(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) @@ -172,29 +175,29 @@ func TestSetAll(t *testing.T) { } func TestSetFind(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") - foundIndex, foundValue := set.Find(func(index int, value interface{}) bool { - return value.(string) == "c" + foundIndex, foundValue := set.Find(func(index int, value string) bool { + return value == "c" }) if foundValue != "c" || foundIndex != 0 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 0) } - foundIndex, foundValue = set.Find(func(index int, value interface{}) bool { - return value.(string) == "x" + foundIndex, foundValue = set.Find(func(index int, value string) bool { + return value == "x" }) - if foundValue != nil || foundIndex != -1 { + if foundValue != "" || foundIndex != -1 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) } } func TestSetChaining(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") } func TestSetIteratorPrevOnEmpty(t *testing.T) { - set := New() + set := New[string]() it := set.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty set") @@ -202,7 +205,7 @@ func TestSetIteratorPrevOnEmpty(t *testing.T) { } func TestSetIteratorNext(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") it := set.Iterator() count := 0 @@ -236,7 +239,7 @@ func TestSetIteratorNext(t *testing.T) { } func TestSetIteratorPrev(t *testing.T) { - set := New() + set := New[string]() set.Add("c", "a", "b") it := set.Iterator() for it.Prev() { @@ -272,7 +275,7 @@ func TestSetIteratorPrev(t *testing.T) { } func TestSetIteratorBegin(t *testing.T) { - set := New() + set := New[string]() it := set.Iterator() it.Begin() set.Add("a", "b", "c") @@ -286,7 +289,7 @@ func TestSetIteratorBegin(t *testing.T) { } func TestSetIteratorEnd(t *testing.T) { - set := New() + set := New[string]() it := set.Iterator() if index := it.Index(); index != -1 { @@ -311,7 +314,7 @@ func TestSetIteratorEnd(t *testing.T) { } func TestSetIteratorFirst(t *testing.T) { - set := New() + set := New[string]() set.Add("a", "b", "c") it := set.Iterator() if actualValue, expectedValue := it.First(), true; actualValue != expectedValue { @@ -323,7 +326,7 @@ func TestSetIteratorFirst(t *testing.T) { } func TestSetIteratorLast(t *testing.T) { - set := New() + set := New[string]() set.Add("a", "b", "c") it := set.Iterator() if actualValue, expectedValue := it.Last(), true; actualValue != expectedValue { @@ -336,13 +339,13 @@ func TestSetIteratorLast(t *testing.T) { func TestSetIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - set := New() + set := New[string]() it := set.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty set") @@ -351,7 +354,7 @@ func TestSetIteratorNextTo(t *testing.T) { // NextTo (not found) { - set := New() + set := New[string]() set.Add("xx", "yy") it := set.Iterator() for it.NextTo(seek) { @@ -361,20 +364,20 @@ func TestSetIteratorNextTo(t *testing.T) { // NextTo (found) { - set := New() + set := New[string]() set.Add("aa", "bb", "cc") it := set.Iterator() it.Begin() if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty set") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -385,13 +388,13 @@ func TestSetIteratorNextTo(t *testing.T) { func TestSetIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - set := New() + set := New[string]() it := set.Iterator() it.End() for it.PrevTo(seek) { @@ -401,7 +404,7 @@ func TestSetIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - set := New() + set := New[string]() set.Add("xx", "yy") it := set.Iterator() it.End() @@ -412,20 +415,20 @@ func TestSetIteratorPrevTo(t *testing.T) { // PrevTo (found) { - set := New() + set := New[string]() set.Add("aa", "bb", "cc") it := set.Iterator() it.End() if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty set") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -435,7 +438,7 @@ func TestSetIteratorPrevTo(t *testing.T) { } func TestSetSerialization(t *testing.T) { - set := New() + set := New[string]() set.Add("a", "b", "c") var err error @@ -464,14 +467,15 @@ func TestSetSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &set) + err = json.Unmarshal([]byte(`["a","b","c"]`), &set) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestSetString(t *testing.T) { - c := New() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "LinkedHashSet") { t.Errorf("String should start with container name") @@ -479,8 +483,8 @@ func TestSetString(t *testing.T) { } func TestSetIntersection(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() intersection := set.Intersection(another) if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { @@ -501,8 +505,8 @@ func TestSetIntersection(t *testing.T) { } func TestSetUnion(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() union := set.Union(another) if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { @@ -523,8 +527,8 @@ func TestSetUnion(t *testing.T) { } func TestSetDifference(t *testing.T) { - set := New() - another := New() + set := New[string]() + another := New[string]() difference := set.Difference(another) if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { @@ -544,7 +548,7 @@ func TestSetDifference(t *testing.T) { } } -func benchmarkContains(b *testing.B, set *Set, size int) { +func benchmarkContains(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Contains(n) @@ -552,7 +556,7 @@ func benchmarkContains(b *testing.B, set *Set, size int) { } } -func benchmarkAdd(b *testing.B, set *Set, size int) { +func benchmarkAdd(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Add(n) @@ -560,7 +564,7 @@ func benchmarkAdd(b *testing.B, set *Set, size int) { } } -func benchmarkRemove(b *testing.B, set *Set, size int) { +func benchmarkRemove(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Remove(n) @@ -571,7 +575,7 @@ func benchmarkRemove(b *testing.B, set *Set, size int) { func BenchmarkHashSetContains100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -582,7 +586,7 @@ func BenchmarkHashSetContains100(b *testing.B) { func BenchmarkHashSetContains1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -593,7 +597,7 @@ func BenchmarkHashSetContains1000(b *testing.B) { func BenchmarkHashSetContains10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -604,7 +608,7 @@ func BenchmarkHashSetContains10000(b *testing.B) { func BenchmarkHashSetContains100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -615,7 +619,7 @@ func BenchmarkHashSetContains100000(b *testing.B) { func BenchmarkHashSetAdd100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() b.StartTimer() benchmarkAdd(b, set, size) } @@ -623,7 +627,7 @@ func BenchmarkHashSetAdd100(b *testing.B) { func BenchmarkHashSetAdd1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -634,7 +638,7 @@ func BenchmarkHashSetAdd1000(b *testing.B) { func BenchmarkHashSetAdd10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -645,7 +649,7 @@ func BenchmarkHashSetAdd10000(b *testing.B) { func BenchmarkHashSetAdd100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -656,7 +660,7 @@ func BenchmarkHashSetAdd100000(b *testing.B) { func BenchmarkHashSetRemove100(b *testing.B) { b.StopTimer() size := 100 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -667,7 +671,7 @@ func BenchmarkHashSetRemove100(b *testing.B) { func BenchmarkHashSetRemove1000(b *testing.B) { b.StopTimer() size := 1000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -678,7 +682,7 @@ func BenchmarkHashSetRemove1000(b *testing.B) { func BenchmarkHashSetRemove10000(b *testing.B) { b.StopTimer() size := 10000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -689,7 +693,7 @@ func BenchmarkHashSetRemove10000(b *testing.B) { func BenchmarkHashSetRemove100000(b *testing.B) { b.StopTimer() size := 100000 - set := New() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } diff --git a/sets/linkedhashset/serialization.go b/sets/linkedhashset/serialization.go index ab2f3b4d..c29e1cbd 100644 --- a/sets/linkedhashset/serialization.go +++ b/sets/linkedhashset/serialization.go @@ -6,21 +6,22 @@ package linkedhashset import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Set)(nil) -var _ containers.JSONDeserializer = (*Set)(nil) +var _ containers.JSONSerializer = (*Set[int])(nil) +var _ containers.JSONDeserializer = (*Set[int])(nil) // ToJSON outputs the JSON representation of the set. -func (set *Set) ToJSON() ([]byte, error) { +func (set *Set[T]) ToJSON() ([]byte, error) { return json.Marshal(set.Values()) } // FromJSON populates the set from the input JSON representation. -func (set *Set) FromJSON(data []byte) error { - elements := []interface{}{} +func (set *Set[T]) FromJSON(data []byte) error { + var elements []T err := json.Unmarshal(data, &elements) if err == nil { set.Clear() @@ -30,11 +31,11 @@ func (set *Set) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (set *Set) UnmarshalJSON(bytes []byte) error { +func (set *Set[T]) UnmarshalJSON(bytes []byte) error { return set.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (set *Set) MarshalJSON() ([]byte, error) { +func (set *Set[T]) MarshalJSON() ([]byte, error) { return set.ToJSON() } diff --git a/sets/sets.go b/sets/sets.go index 9641951a..e63bb5fe 100644 --- a/sets/sets.go +++ b/sets/sets.go @@ -9,18 +9,17 @@ // Reference: https://en.wikipedia.org/wiki/Set_%28abstract_data_type%29 package sets -import "github.com/emirpasic/gods/containers" +import ( + "github.com/emirpasic/gods/v2/containers" +) // Set interface that all sets implement -type Set interface { - Add(elements ...interface{}) - Remove(elements ...interface{}) - Contains(elements ...interface{}) bool - // Intersection(another *Set) *Set - // Union(another *Set) *Set - // Difference(another *Set) *Set +type Set[T comparable] interface { + Add(elements ...T) + Remove(elements ...T) + Contains(elements ...T) bool - containers.Container + containers.Container[T] // Empty() bool // Size() int // Clear() diff --git a/sets/treeset/enumerable.go b/sets/treeset/enumerable.go index c774834a..b41a495a 100644 --- a/sets/treeset/enumerable.go +++ b/sets/treeset/enumerable.go @@ -5,15 +5,15 @@ package treeset import ( - "github.com/emirpasic/gods/containers" - rbt "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/v2/containers" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // Assert Enumerable implementation -var _ containers.EnumerableWithIndex = (*Set)(nil) +var _ containers.EnumerableWithIndex[int] = (*Set[int])(nil) // Each calls the given function once for each element, passing that element's index and value. -func (set *Set) Each(f func(index int, value interface{})) { +func (set *Set[T]) Each(f func(index int, value T)) { iterator := set.Iterator() for iterator.Next() { f(iterator.Index(), iterator.Value()) @@ -22,8 +22,8 @@ func (set *Set) Each(f func(index int, value interface{})) { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. -func (set *Set) Map(f func(index int, value interface{}) interface{}) *Set { - newSet := &Set{tree: rbt.NewWith(set.tree.Comparator)} +func (set *Set[T]) Map(f func(index int, value T) T) *Set[T] { + newSet := &Set[T]{tree: rbt.NewWith[T, struct{}](set.tree.Comparator)} iterator := set.Iterator() for iterator.Next() { newSet.Add(f(iterator.Index(), iterator.Value())) @@ -32,8 +32,8 @@ func (set *Set) Map(f func(index int, value interface{}) interface{}) *Set { } // Select returns a new container containing all elements for which the given function returns a true value. -func (set *Set) Select(f func(index int, value interface{}) bool) *Set { - newSet := &Set{tree: rbt.NewWith(set.tree.Comparator)} +func (set *Set[T]) Select(f func(index int, value T) bool) *Set[T] { + newSet := &Set[T]{tree: rbt.NewWith[T, struct{}](set.tree.Comparator)} iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -45,7 +45,7 @@ func (set *Set) Select(f func(index int, value interface{}) bool) *Set { // Any passes each element of the container to the given function and // returns true if the function ever returns true for any element. -func (set *Set) Any(f func(index int, value interface{}) bool) bool { +func (set *Set[T]) Any(f func(index int, value T) bool) bool { iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { @@ -57,7 +57,7 @@ func (set *Set) Any(f func(index int, value interface{}) bool) bool { // All passes each element of the container to the given function and // returns true if the function returns true for all elements. -func (set *Set) All(f func(index int, value interface{}) bool) bool { +func (set *Set[T]) All(f func(index int, value T) bool) bool { iterator := set.Iterator() for iterator.Next() { if !f(iterator.Index(), iterator.Value()) { @@ -70,12 +70,13 @@ func (set *Set) All(f func(index int, value interface{}) bool) bool { // Find passes each element of the container to the given function and returns // the first (index,value) for which the function is true or -1,nil otherwise // if no element matches the criteria. -func (set *Set) Find(f func(index int, value interface{}) bool) (int, interface{}) { +func (set *Set[T]) Find(f func(index int, value T) bool) (int, T) { iterator := set.Iterator() for iterator.Next() { if f(iterator.Index(), iterator.Value()) { return iterator.Index(), iterator.Value() } } - return -1, nil + var t T + return -1, t } diff --git a/sets/treeset/iterator.go b/sets/treeset/iterator.go index 88a0bea7..435b46b0 100644 --- a/sets/treeset/iterator.go +++ b/sets/treeset/iterator.go @@ -5,30 +5,30 @@ package treeset import ( - "github.com/emirpasic/gods/containers" - rbt "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/v2/containers" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { +type Iterator[T comparable] struct { index int - iterator rbt.Iterator - tree *rbt.Tree + iterator *rbt.Iterator[T, struct{}] + tree *rbt.Tree[T, struct{}] } // Iterator holding the iterator's state -func (set *Set) Iterator() Iterator { - return Iterator{index: -1, iterator: set.tree.Iterator(), tree: set.tree} +func (set *Set[T]) Iterator() Iterator[T] { + return Iterator[T]{index: -1, iterator: set.tree.Iterator(), tree: set.tree} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.tree.Size() { iterator.index++ } @@ -38,7 +38,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -47,26 +47,26 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { return iterator.iterator.Key() } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 iterator.iterator.Begin() } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.tree.Size() iterator.iterator.End() } @@ -74,7 +74,7 @@ func (iterator *Iterator) End() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -82,7 +82,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -91,7 +91,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -105,7 +105,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/sets/treeset/serialization.go b/sets/treeset/serialization.go index 76d049dd..5543fd36 100644 --- a/sets/treeset/serialization.go +++ b/sets/treeset/serialization.go @@ -6,21 +6,22 @@ package treeset import ( "encoding/json" - "github.com/emirpasic/gods/containers" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Set)(nil) -var _ containers.JSONDeserializer = (*Set)(nil) +var _ containers.JSONSerializer = (*Set[int])(nil) +var _ containers.JSONDeserializer = (*Set[int])(nil) // ToJSON outputs the JSON representation of the set. -func (set *Set) ToJSON() ([]byte, error) { +func (set *Set[T]) ToJSON() ([]byte, error) { return json.Marshal(set.Values()) } // FromJSON populates the set from the input JSON representation. -func (set *Set) FromJSON(data []byte) error { - elements := []interface{}{} +func (set *Set[T]) FromJSON(data []byte) error { + var elements []T err := json.Unmarshal(data, &elements) if err == nil { set.Clear() @@ -30,11 +31,11 @@ func (set *Set) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (set *Set) UnmarshalJSON(bytes []byte) error { +func (set *Set[T]) UnmarshalJSON(bytes []byte) error { return set.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (set *Set) MarshalJSON() ([]byte, error) { +func (set *Set[T]) MarshalJSON() ([]byte, error) { return set.ToJSON() } diff --git a/sets/treeset/treeset.go b/sets/treeset/treeset.go index 3507cc90..6b63a7e1 100644 --- a/sets/treeset/treeset.go +++ b/sets/treeset/treeset.go @@ -10,45 +10,33 @@ package treeset import ( + "cmp" "fmt" - "github.com/emirpasic/gods/sets" - rbt "github.com/emirpasic/gods/trees/redblacktree" - "github.com/emirpasic/gods/utils" "reflect" "strings" + + "github.com/emirpasic/gods/v2/sets" + rbt "github.com/emirpasic/gods/v2/trees/redblacktree" + "github.com/emirpasic/gods/v2/utils" ) // Assert Set implementation -var _ sets.Set = (*Set)(nil) +var _ sets.Set[int] = (*Set[int])(nil) // Set holds elements in a red-black tree -type Set struct { - tree *rbt.Tree +type Set[T comparable] struct { + tree *rbt.Tree[T, struct{}] } var itemExists = struct{}{} -// NewWith instantiates a new empty set with the custom comparator. -func NewWith(comparator utils.Comparator, values ...interface{}) *Set { - set := &Set{tree: rbt.NewWith(comparator)} - if len(values) > 0 { - set.Add(values...) - } - return set +func New[T cmp.Ordered](values ...T) *Set[T] { + return NewWith[T](cmp.Compare[T], values...) } -// NewWithIntComparator instantiates a new empty set with the IntComparator, i.e. keys are of type int. -func NewWithIntComparator(values ...interface{}) *Set { - set := &Set{tree: rbt.NewWithIntComparator()} - if len(values) > 0 { - set.Add(values...) - } - return set -} - -// NewWithStringComparator instantiates a new empty set with the StringComparator, i.e. keys are of type string. -func NewWithStringComparator(values ...interface{}) *Set { - set := &Set{tree: rbt.NewWithStringComparator()} +// NewWith instantiates a new empty set with the custom comparator. +func NewWith[T comparable](comparator utils.Comparator[T], values ...T) *Set[T] { + set := &Set[T]{tree: rbt.NewWith[T, struct{}](comparator)} if len(values) > 0 { set.Add(values...) } @@ -56,14 +44,14 @@ func NewWithStringComparator(values ...interface{}) *Set { } // Add adds the items (one or more) to the set. -func (set *Set) Add(items ...interface{}) { +func (set *Set[T]) Add(items ...T) { for _, item := range items { set.tree.Put(item, itemExists) } } // Remove removes the items (one or more) from the set. -func (set *Set) Remove(items ...interface{}) { +func (set *Set[T]) Remove(items ...T) { for _, item := range items { set.tree.Remove(item) } @@ -72,7 +60,7 @@ func (set *Set) Remove(items ...interface{}) { // Contains checks weather items (one or more) are present in the set. // All items have to be present in the set for the method to return true. // Returns true if no arguments are passed at all, i.e. set is always superset of empty set. -func (set *Set) Contains(items ...interface{}) bool { +func (set *Set[T]) Contains(items ...T) bool { for _, item := range items { if _, contains := set.tree.Get(item); !contains { return false @@ -82,27 +70,27 @@ func (set *Set) Contains(items ...interface{}) bool { } // Empty returns true if set does not contain any elements. -func (set *Set) Empty() bool { +func (set *Set[T]) Empty() bool { return set.tree.Size() == 0 } // Size returns number of elements within the set. -func (set *Set) Size() int { +func (set *Set[T]) Size() int { return set.tree.Size() } // Clear clears all values in the set. -func (set *Set) Clear() { +func (set *Set[T]) Clear() { set.tree.Clear() } // Values returns all items in the set. -func (set *Set) Values() []interface{} { +func (set *Set[T]) Values() []T { return set.tree.Keys() } // String returns a string representation of container -func (set *Set) String() string { +func (set *Set[T]) String() string { str := "TreeSet\n" items := []string{} for _, v := range set.tree.Keys() { @@ -116,7 +104,7 @@ func (set *Set) String() string { // The new set consists of all elements that are both in "set" and "another". // The two sets should have the same comparators, otherwise the result is empty set. // Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) -func (set *Set) Intersection(another *Set) *Set { +func (set *Set[T]) Intersection(another *Set[T]) *Set[T] { result := NewWith(set.tree.Comparator) setComparator := reflect.ValueOf(set.tree.Comparator) @@ -147,7 +135,7 @@ func (set *Set) Intersection(another *Set) *Set { // The new set consists of all elements that are in "set" or "another" (possibly both). // The two sets should have the same comparators, otherwise the result is empty set. // Ref: https://en.wikipedia.org/wiki/Union_(set_theory) -func (set *Set) Union(another *Set) *Set { +func (set *Set[T]) Union(another *Set[T]) *Set[T] { result := NewWith(set.tree.Comparator) setComparator := reflect.ValueOf(set.tree.Comparator) @@ -170,7 +158,7 @@ func (set *Set) Union(another *Set) *Set { // The two sets should have the same comparators, otherwise the result is empty set. // The new set consists of all elements that are in "set" but not in "another". // Ref: https://proofwiki.org/wiki/Definition:Set_Difference -func (set *Set) Difference(another *Set) *Set { +func (set *Set[T]) Difference(another *Set[T]) *Set[T] { result := NewWith(set.tree.Comparator) setComparator := reflect.ValueOf(set.tree.Comparator) diff --git a/sets/treeset/treeset_test.go b/sets/treeset/treeset_test.go index 9e7c6707..709ed84d 100644 --- a/sets/treeset/treeset_test.go +++ b/sets/treeset/treeset_test.go @@ -9,10 +9,12 @@ import ( "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestSetNew(t *testing.T) { - set := NewWithIntComparator(2, 1) + set := New[int](2, 1) if actualValue := set.Size(); actualValue != 2 { t.Errorf("Got %v expected %v", actualValue, 2) } @@ -26,7 +28,7 @@ func TestSetNew(t *testing.T) { } func TestSetAdd(t *testing.T) { - set := NewWithIntComparator() + set := New[int]() set.Add() set.Add(1) set.Add(2) @@ -38,13 +40,11 @@ func TestSetAdd(t *testing.T) { if actualValue := set.Size(); actualValue != 3 { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, expectedValue := fmt.Sprintf("%d%d%d", set.Values()...), "123"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, set.Values(), []int{1, 2, 3}) } func TestSetContains(t *testing.T) { - set := NewWithIntComparator() + set := New[int]() set.Add(3, 1, 2) if actualValue := set.Contains(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) @@ -61,7 +61,7 @@ func TestSetContains(t *testing.T) { } func TestSetRemove(t *testing.T) { - set := NewWithIntComparator() + set := New[int]() set.Add(3, 1, 2) set.Remove() if actualValue := set.Size(); actualValue != 3 { @@ -81,9 +81,9 @@ func TestSetRemove(t *testing.T) { } func TestSetEach(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - set.Each(func(index int, value interface{}) { + set.Each(func(index int, value string) { switch index { case 0: if actualValue, expectedValue := value, "a"; actualValue != expectedValue { @@ -104,10 +104,10 @@ func TestSetEach(t *testing.T) { } func TestSetMap(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - mappedSet := set.Map(func(index int, value interface{}) interface{} { - return "mapped: " + value.(string) + mappedSet := set.Map(func(index int, value string) string { + return "mapped: " + value }) if actualValue, expectedValue := mappedSet.Contains("mapped: a", "mapped: b", "mapped: c"), true; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -121,10 +121,10 @@ func TestSetMap(t *testing.T) { } func TestSetSelect(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - selectedSet := set.Select(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + selectedSet := set.Select(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if actualValue, expectedValue := selectedSet.Contains("a", "b"), true; actualValue != expectedValue { fmt.Println("A: ", selectedSet.Contains("b")) @@ -139,16 +139,16 @@ func TestSetSelect(t *testing.T) { } func TestSetAny(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - any := set.Any(func(index int, value interface{}) bool { - return value.(string) == "c" + any := set.Any(func(index int, value string) bool { + return value == "c" }) if any != true { t.Errorf("Got %v expected %v", any, true) } - any = set.Any(func(index int, value interface{}) bool { - return value.(string) == "x" + any = set.Any(func(index int, value string) bool { + return value == "x" }) if any != false { t.Errorf("Got %v expected %v", any, false) @@ -156,16 +156,16 @@ func TestSetAny(t *testing.T) { } func TestSetAll(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - all := set.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "c" + all := set.All(func(index int, value string) bool { + return value >= "a" && value <= "c" }) if all != true { t.Errorf("Got %v expected %v", all, true) } - all = set.All(func(index int, value interface{}) bool { - return value.(string) >= "a" && value.(string) <= "b" + all = set.All(func(index int, value string) bool { + return value >= "a" && value <= "b" }) if all != false { t.Errorf("Got %v expected %v", all, false) @@ -173,29 +173,29 @@ func TestSetAll(t *testing.T) { } func TestSetFind(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") - foundIndex, foundValue := set.Find(func(index int, value interface{}) bool { - return value.(string) == "c" + foundIndex, foundValue := set.Find(func(index int, value string) bool { + return value == "c" }) if foundValue != "c" || foundIndex != 2 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 2) } - foundIndex, foundValue = set.Find(func(index int, value interface{}) bool { - return value.(string) == "x" + foundIndex, foundValue = set.Find(func(index int, value string) bool { + return value == "x" }) - if foundValue != nil || foundIndex != -1 { + if foundValue != "" || foundIndex != -1 { t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) } } func TestSetChaining(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") } func TestSetIteratorNextOnEmpty(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty set") @@ -203,7 +203,7 @@ func TestSetIteratorNextOnEmpty(t *testing.T) { } func TestSetIteratorPrevOnEmpty(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty set") @@ -211,7 +211,7 @@ func TestSetIteratorPrevOnEmpty(t *testing.T) { } func TestSetIteratorNext(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") it := set.Iterator() count := 0 @@ -245,7 +245,7 @@ func TestSetIteratorNext(t *testing.T) { } func TestSetIteratorPrev(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("c", "a", "b") it := set.Iterator() for it.Prev() { @@ -281,7 +281,7 @@ func TestSetIteratorPrev(t *testing.T) { } func TestSetIteratorBegin(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() it.Begin() set.Add("a", "b", "c") @@ -295,7 +295,7 @@ func TestSetIteratorBegin(t *testing.T) { } func TestSetIteratorEnd(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() if index := it.Index(); index != -1 { @@ -320,7 +320,7 @@ func TestSetIteratorEnd(t *testing.T) { } func TestSetIteratorFirst(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("a", "b", "c") it := set.Iterator() if actualValue, expectedValue := it.First(), true; actualValue != expectedValue { @@ -332,7 +332,7 @@ func TestSetIteratorFirst(t *testing.T) { } func TestSetIteratorLast(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("a", "b", "c") it := set.Iterator() if actualValue, expectedValue := it.Last(), true; actualValue != expectedValue { @@ -345,13 +345,13 @@ func TestSetIteratorLast(t *testing.T) { func TestSetIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty set") @@ -360,7 +360,7 @@ func TestSetIteratorNextTo(t *testing.T) { // NextTo (not found) { - set := NewWithStringComparator() + set := New[string]() set.Add("xx", "yy") it := set.Iterator() for it.NextTo(seek) { @@ -370,20 +370,20 @@ func TestSetIteratorNextTo(t *testing.T) { // NextTo (found) { - set := NewWithStringComparator() + set := New[string]() set.Add("aa", "bb", "cc") it := set.Iterator() it.Begin() if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty set") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -394,13 +394,13 @@ func TestSetIteratorNextTo(t *testing.T) { func TestSetIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - set := NewWithStringComparator() + set := New[string]() it := set.Iterator() it.End() for it.PrevTo(seek) { @@ -410,7 +410,7 @@ func TestSetIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - set := NewWithStringComparator() + set := New[string]() set.Add("xx", "yy") it := set.Iterator() it.End() @@ -421,20 +421,20 @@ func TestSetIteratorPrevTo(t *testing.T) { // PrevTo (found) { - set := NewWithStringComparator() + set := New[string]() set.Add("aa", "bb", "cc") it := set.Iterator() it.End() if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty set") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -444,7 +444,7 @@ func TestSetIteratorPrevTo(t *testing.T) { } func TestSetSerialization(t *testing.T) { - set := NewWithStringComparator() + set := New[string]() set.Add("a", "b", "c") var err error @@ -480,7 +480,7 @@ func TestSetSerialization(t *testing.T) { } func TestSetString(t *testing.T) { - c := NewWithIntComparator() + c := New[int]() c.Add(1) if !strings.HasPrefix(c.String(), "TreeSet") { t.Errorf("String should start with container name") @@ -488,19 +488,8 @@ func TestSetString(t *testing.T) { } func TestSetIntersection(t *testing.T) { - { - set := NewWithStringComparator() - another := NewWithIntComparator() - set.Add("a", "b", "c", "d") - another.Add(1, 2, 3, 4) - difference := set.Difference(another) - if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - } - - set := NewWithStringComparator() - another := NewWithStringComparator() + set := New[string]() + another := New[string]() intersection := set.Intersection(another) if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { @@ -521,19 +510,8 @@ func TestSetIntersection(t *testing.T) { } func TestSetUnion(t *testing.T) { - { - set := NewWithStringComparator() - another := NewWithIntComparator() - set.Add("a", "b", "c", "d") - another.Add(1, 2, 3, 4) - difference := set.Difference(another) - if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - } - - set := NewWithStringComparator() - another := NewWithStringComparator() + set := New[string]() + another := New[string]() union := set.Union(another) if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { @@ -554,19 +532,8 @@ func TestSetUnion(t *testing.T) { } func TestSetDifference(t *testing.T) { - { - set := NewWithStringComparator() - another := NewWithIntComparator() - set.Add("a", "b", "c", "d") - another.Add(1, 2, 3, 4) - difference := set.Difference(another) - if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - } - - set := NewWithStringComparator() - another := NewWithStringComparator() + set := New[string]() + another := New[string]() difference := set.Difference(another) if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { @@ -586,7 +553,7 @@ func TestSetDifference(t *testing.T) { } } -func benchmarkContains(b *testing.B, set *Set, size int) { +func benchmarkContains(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Contains(n) @@ -594,7 +561,7 @@ func benchmarkContains(b *testing.B, set *Set, size int) { } } -func benchmarkAdd(b *testing.B, set *Set, size int) { +func benchmarkAdd(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Add(n) @@ -602,7 +569,7 @@ func benchmarkAdd(b *testing.B, set *Set, size int) { } } -func benchmarkRemove(b *testing.B, set *Set, size int) { +func benchmarkRemove(b *testing.B, set *Set[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { set.Remove(n) @@ -613,7 +580,7 @@ func benchmarkRemove(b *testing.B, set *Set, size int) { func BenchmarkTreeSetContains100(b *testing.B) { b.StopTimer() size := 100 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -624,7 +591,7 @@ func BenchmarkTreeSetContains100(b *testing.B) { func BenchmarkTreeSetContains1000(b *testing.B) { b.StopTimer() size := 1000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -635,7 +602,7 @@ func BenchmarkTreeSetContains1000(b *testing.B) { func BenchmarkTreeSetContains10000(b *testing.B) { b.StopTimer() size := 10000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -646,7 +613,7 @@ func BenchmarkTreeSetContains10000(b *testing.B) { func BenchmarkTreeSetContains100000(b *testing.B) { b.StopTimer() size := 100000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -657,7 +624,7 @@ func BenchmarkTreeSetContains100000(b *testing.B) { func BenchmarkTreeSetAdd100(b *testing.B) { b.StopTimer() size := 100 - set := NewWithIntComparator() + set := New[int]() b.StartTimer() benchmarkAdd(b, set, size) } @@ -665,7 +632,7 @@ func BenchmarkTreeSetAdd100(b *testing.B) { func BenchmarkTreeSetAdd1000(b *testing.B) { b.StopTimer() size := 1000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -676,7 +643,7 @@ func BenchmarkTreeSetAdd1000(b *testing.B) { func BenchmarkTreeSetAdd10000(b *testing.B) { b.StopTimer() size := 10000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -687,7 +654,7 @@ func BenchmarkTreeSetAdd10000(b *testing.B) { func BenchmarkTreeSetAdd100000(b *testing.B) { b.StopTimer() size := 100000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -698,7 +665,7 @@ func BenchmarkTreeSetAdd100000(b *testing.B) { func BenchmarkTreeSetRemove100(b *testing.B) { b.StopTimer() size := 100 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -709,7 +676,7 @@ func BenchmarkTreeSetRemove100(b *testing.B) { func BenchmarkTreeSetRemove1000(b *testing.B) { b.StopTimer() size := 1000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -720,7 +687,7 @@ func BenchmarkTreeSetRemove1000(b *testing.B) { func BenchmarkTreeSetRemove10000(b *testing.B) { b.StopTimer() size := 10000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } @@ -731,7 +698,7 @@ func BenchmarkTreeSetRemove10000(b *testing.B) { func BenchmarkTreeSetRemove100000(b *testing.B) { b.StopTimer() size := 100000 - set := NewWithIntComparator() + set := New[int]() for n := 0; n < size; n++ { set.Add(n) } diff --git a/stacks/arraystack/arraystack.go b/stacks/arraystack/arraystack.go index 78c3dda3..ec80cb20 100644 --- a/stacks/arraystack/arraystack.go +++ b/stacks/arraystack/arraystack.go @@ -11,32 +11,33 @@ package arraystack import ( "fmt" - "github.com/emirpasic/gods/lists/arraylist" - "github.com/emirpasic/gods/stacks" "strings" + + "github.com/emirpasic/gods/v2/lists/arraylist" + "github.com/emirpasic/gods/v2/stacks" ) // Assert Stack implementation -var _ stacks.Stack = (*Stack)(nil) +var _ stacks.Stack[int] = (*Stack[int])(nil) // Stack holds elements in an array-list -type Stack struct { - list *arraylist.List +type Stack[T comparable] struct { + list *arraylist.List[T] } // New instantiates a new empty stack -func New() *Stack { - return &Stack{list: arraylist.New()} +func New[T comparable]() *Stack[T] { + return &Stack[T]{list: arraylist.New[T]()} } // Push adds a value onto the top of the stack -func (stack *Stack) Push(value interface{}) { +func (stack *Stack[T]) Push(value T) { stack.list.Add(value) } // Pop removes top element on stack and returns it, or nil if stack is empty. // Second return parameter is true, unless the stack was empty and there was nothing to pop. -func (stack *Stack) Pop() (value interface{}, ok bool) { +func (stack *Stack[T]) Pop() (value T, ok bool) { value, ok = stack.list.Get(stack.list.Size() - 1) stack.list.Remove(stack.list.Size() - 1) return @@ -44,29 +45,29 @@ func (stack *Stack) Pop() (value interface{}, ok bool) { // Peek returns top element on the stack without removing it, or nil if stack is empty. // Second return parameter is true, unless the stack was empty and there was nothing to peek. -func (stack *Stack) Peek() (value interface{}, ok bool) { +func (stack *Stack[T]) Peek() (value T, ok bool) { return stack.list.Get(stack.list.Size() - 1) } // Empty returns true if stack does not contain any elements. -func (stack *Stack) Empty() bool { +func (stack *Stack[T]) Empty() bool { return stack.list.Empty() } // Size returns number of elements within the stack. -func (stack *Stack) Size() int { +func (stack *Stack[T]) Size() int { return stack.list.Size() } // Clear removes all elements from the stack. -func (stack *Stack) Clear() { +func (stack *Stack[T]) Clear() { stack.list.Clear() } // Values returns all elements in the stack (LIFO order). -func (stack *Stack) Values() []interface{} { +func (stack *Stack[T]) Values() []T { size := stack.list.Size() - elements := make([]interface{}, size, size) + elements := make([]T, size, size) for i := 1; i <= size; i++ { elements[size-i], _ = stack.list.Get(i - 1) // in reverse (LIFO) } @@ -74,7 +75,7 @@ func (stack *Stack) Values() []interface{} { } // String returns a string representation of container -func (stack *Stack) String() string { +func (stack *Stack[T]) String() string { str := "ArrayStack\n" values := []string{} for _, value := range stack.list.Values() { @@ -85,6 +86,6 @@ func (stack *Stack) String() string { } // Check that the index is within bounds of the list -func (stack *Stack) withinRange(index int) bool { +func (stack *Stack[T]) withinRange(index int) bool { return index >= 0 && index < stack.list.Size() } diff --git a/stacks/arraystack/arraystack_test.go b/stacks/arraystack/arraystack_test.go index eba42870..4e754608 100644 --- a/stacks/arraystack/arraystack_test.go +++ b/stacks/arraystack/arraystack_test.go @@ -6,13 +6,14 @@ package arraystack import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestStackPush(t *testing.T) { - stack := New() + stack := New[int]() if actualValue := stack.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -20,7 +21,7 @@ func TestStackPush(t *testing.T) { stack.Push(2) stack.Push(3) - if actualValue := stack.Values(); actualValue[0].(int) != 3 || actualValue[1].(int) != 2 || actualValue[2].(int) != 1 { + if actualValue := stack.Values(); actualValue[0] != 3 || actualValue[1] != 2 || actualValue[2] != 1 { t.Errorf("Got %v expected %v", actualValue, "[3,2,1]") } if actualValue := stack.Empty(); actualValue != false { @@ -35,8 +36,8 @@ func TestStackPush(t *testing.T) { } func TestStackPeek(t *testing.T) { - stack := New() - if actualValue, ok := stack.Peek(); actualValue != nil || ok { + stack := New[int]() + if actualValue, ok := stack.Peek(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } stack.Push(1) @@ -48,7 +49,7 @@ func TestStackPeek(t *testing.T) { } func TestStackPop(t *testing.T) { - stack := New() + stack := New[int]() stack.Push(1) stack.Push(2) stack.Push(3) @@ -62,7 +63,7 @@ func TestStackPop(t *testing.T) { if actualValue, ok := stack.Pop(); actualValue != 1 || !ok { t.Errorf("Got %v expected %v", actualValue, 1) } - if actualValue, ok := stack.Pop(); actualValue != nil || ok { + if actualValue, ok := stack.Pop(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := stack.Empty(); actualValue != true { @@ -74,7 +75,7 @@ func TestStackPop(t *testing.T) { } func TestStackIteratorOnEmpty(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty stack") @@ -82,7 +83,7 @@ func TestStackIteratorOnEmpty(t *testing.T) { } func TestStackIteratorNext(t *testing.T) { - stack := New() + stack := New[string]() stack.Push("a") stack.Push("b") stack.Push("c") @@ -125,7 +126,7 @@ func TestStackIteratorNext(t *testing.T) { } func TestStackIteratorPrev(t *testing.T) { - stack := New() + stack := New[string]() stack.Push("a") stack.Push("b") stack.Push("c") @@ -164,7 +165,7 @@ func TestStackIteratorPrev(t *testing.T) { } func TestStackIteratorBegin(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() it.Begin() stack.Push("a") @@ -180,7 +181,7 @@ func TestStackIteratorBegin(t *testing.T) { } func TestStackIteratorEnd(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() if index := it.Index(); index != -1 { @@ -207,7 +208,7 @@ func TestStackIteratorEnd(t *testing.T) { } func TestStackIteratorFirst(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -224,7 +225,7 @@ func TestStackIteratorFirst(t *testing.T) { } func TestStackIteratorLast(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -242,13 +243,13 @@ func TestStackIteratorLast(t *testing.T) { func TestStackIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - stack := New() + stack := New[string]() it := stack.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty stack") @@ -257,7 +258,7 @@ func TestStackIteratorNextTo(t *testing.T) { // NextTo (not found) { - stack := New() + stack := New[string]() stack.Push("xx") stack.Push("yy") it := stack.Iterator() @@ -268,7 +269,7 @@ func TestStackIteratorNextTo(t *testing.T) { // NextTo (found) { - stack := New() + stack := New[string]() stack.Push("aa") stack.Push("bb") stack.Push("cc") @@ -277,13 +278,13 @@ func TestStackIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty stack") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 2 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "aa") } if it.Next() { @@ -294,13 +295,13 @@ func TestStackIteratorNextTo(t *testing.T) { func TestStackIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - stack := New() + stack := New[string]() it := stack.Iterator() it.End() for it.PrevTo(seek) { @@ -310,7 +311,7 @@ func TestStackIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - stack := New() + stack := New[string]() stack.Push("xx") stack.Push("yy") it := stack.Iterator() @@ -322,7 +323,7 @@ func TestStackIteratorPrevTo(t *testing.T) { // PrevTo (found) { - stack := New() + stack := New[string]() stack.Push("aa") stack.Push("bb") stack.Push("cc") @@ -331,13 +332,13 @@ func TestStackIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty stack") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 0 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "cc") } if it.Prev() { @@ -347,16 +348,14 @@ func TestStackIteratorPrevTo(t *testing.T) { } func TestStackSerialization(t *testing.T) { - stack := New() + stack := New[string]() stack.Push("a") stack.Push("b") stack.Push("c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", stack.Values()...), "cba"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, stack.Values(), []string{"c", "b", "a"}) if actualValue, expectedValue := stack.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -378,21 +377,22 @@ func TestStackSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &stack) + err = json.Unmarshal([]byte(`["a","b","c"]`), &stack) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestStackString(t *testing.T) { - c := New() + c := New[int]() c.Push(1) if !strings.HasPrefix(c.String(), "ArrayStack") { t.Errorf("String should start with container name") } } -func benchmarkPush(b *testing.B, stack *Stack, size int) { +func benchmarkPush(b *testing.B, stack *Stack[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { stack.Push(n) @@ -400,7 +400,7 @@ func benchmarkPush(b *testing.B, stack *Stack, size int) { } } -func benchmarkPop(b *testing.B, stack *Stack, size int) { +func benchmarkPop(b *testing.B, stack *Stack[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { stack.Pop() @@ -411,7 +411,7 @@ func benchmarkPop(b *testing.B, stack *Stack, size int) { func BenchmarkArrayStackPop100(b *testing.B) { b.StopTimer() size := 100 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -422,7 +422,7 @@ func BenchmarkArrayStackPop100(b *testing.B) { func BenchmarkArrayStackPop1000(b *testing.B) { b.StopTimer() size := 1000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -433,7 +433,7 @@ func BenchmarkArrayStackPop1000(b *testing.B) { func BenchmarkArrayStackPop10000(b *testing.B) { b.StopTimer() size := 10000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -444,7 +444,7 @@ func BenchmarkArrayStackPop10000(b *testing.B) { func BenchmarkArrayStackPop100000(b *testing.B) { b.StopTimer() size := 100000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -455,7 +455,7 @@ func BenchmarkArrayStackPop100000(b *testing.B) { func BenchmarkArrayStackPush100(b *testing.B) { b.StopTimer() size := 100 - stack := New() + stack := New[int]() b.StartTimer() benchmarkPush(b, stack, size) } @@ -463,7 +463,7 @@ func BenchmarkArrayStackPush100(b *testing.B) { func BenchmarkArrayStackPush1000(b *testing.B) { b.StopTimer() size := 1000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -474,7 +474,7 @@ func BenchmarkArrayStackPush1000(b *testing.B) { func BenchmarkArrayStackPush10000(b *testing.B) { b.StopTimer() size := 10000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -485,7 +485,7 @@ func BenchmarkArrayStackPush10000(b *testing.B) { func BenchmarkArrayStackPush100000(b *testing.B) { b.StopTimer() size := 100000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } diff --git a/stacks/arraystack/iterator.go b/stacks/arraystack/iterator.go index c01d7c8b..651e41a0 100644 --- a/stacks/arraystack/iterator.go +++ b/stacks/arraystack/iterator.go @@ -4,27 +4,27 @@ package arraystack -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - stack *Stack +type Iterator[T comparable] struct { + stack *Stack[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (stack *Stack) Iterator() Iterator { - return Iterator{stack: stack, index: -1} +func (stack *Stack[T]) Iterator() *Iterator[T] { + return &Iterator[T]{stack: stack, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.stack.Size() { iterator.index++ } @@ -34,7 +34,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -43,33 +43,33 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { value, _ := iterator.stack.list.Get(iterator.stack.list.Size() - iterator.index - 1) // in reverse (LIFO) return value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.stack.Size() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -77,7 +77,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -86,7 +86,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -100,7 +100,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/stacks/arraystack/serialization.go b/stacks/arraystack/serialization.go index e65889fb..ff0c6385 100644 --- a/stacks/arraystack/serialization.go +++ b/stacks/arraystack/serialization.go @@ -5,29 +5,29 @@ package arraystack import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Stack)(nil) -var _ containers.JSONDeserializer = (*Stack)(nil) +var _ containers.JSONSerializer = (*Stack[int])(nil) +var _ containers.JSONDeserializer = (*Stack[int])(nil) // ToJSON outputs the JSON representation of the stack. -func (stack *Stack) ToJSON() ([]byte, error) { +func (stack *Stack[T]) ToJSON() ([]byte, error) { return stack.list.ToJSON() } // FromJSON populates the stack from the input JSON representation. -func (stack *Stack) FromJSON(data []byte) error { +func (stack *Stack[T]) FromJSON(data []byte) error { return stack.list.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (stack *Stack) UnmarshalJSON(bytes []byte) error { +func (stack *Stack[T]) UnmarshalJSON(bytes []byte) error { return stack.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (stack *Stack) MarshalJSON() ([]byte, error) { +func (stack *Stack[T]) MarshalJSON() ([]byte, error) { return stack.ToJSON() } diff --git a/stacks/linkedliststack/iterator.go b/stacks/linkedliststack/iterator.go index dea086b1..5350e330 100644 --- a/stacks/linkedliststack/iterator.go +++ b/stacks/linkedliststack/iterator.go @@ -4,27 +4,27 @@ package linkedliststack -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.IteratorWithIndex = (*Iterator)(nil) +var _ containers.IteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - stack *Stack +type Iterator[T comparable] struct { + stack *Stack[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (stack *Stack) Iterator() Iterator { - return Iterator{stack: stack, index: -1} +func (stack *Stack[T]) Iterator() *Iterator[T] { + return &Iterator[T]{stack: stack, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.stack.Size() { iterator.index++ } @@ -33,27 +33,27 @@ func (iterator *Iterator) Next() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { value, _ := iterator.stack.list.Get(iterator.index) // in reverse (LIFO) return value } // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -62,7 +62,7 @@ func (iterator *Iterator) First() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/stacks/linkedliststack/linkedliststack.go b/stacks/linkedliststack/linkedliststack.go index ce69b212..ec373dd5 100644 --- a/stacks/linkedliststack/linkedliststack.go +++ b/stacks/linkedliststack/linkedliststack.go @@ -11,32 +11,33 @@ package linkedliststack import ( "fmt" - "github.com/emirpasic/gods/lists/singlylinkedlist" - "github.com/emirpasic/gods/stacks" "strings" + + "github.com/emirpasic/gods/v2/lists/singlylinkedlist" + "github.com/emirpasic/gods/v2/stacks" ) // Assert Stack implementation -var _ stacks.Stack = (*Stack)(nil) +var _ stacks.Stack[int] = (*Stack[int])(nil) // Stack holds elements in a singly-linked-list -type Stack struct { - list *singlylinkedlist.List +type Stack[T comparable] struct { + list *singlylinkedlist.List[T] } // New nnstantiates a new empty stack -func New() *Stack { - return &Stack{list: &singlylinkedlist.List{}} +func New[T comparable]() *Stack[T] { + return &Stack[T]{list: singlylinkedlist.New[T]()} } // Push adds a value onto the top of the stack -func (stack *Stack) Push(value interface{}) { +func (stack *Stack[T]) Push(value T) { stack.list.Prepend(value) } // Pop removes top element on stack and returns it, or nil if stack is empty. // Second return parameter is true, unless the stack was empty and there was nothing to pop. -func (stack *Stack) Pop() (value interface{}, ok bool) { +func (stack *Stack[T]) Pop() (value T, ok bool) { value, ok = stack.list.Get(0) stack.list.Remove(0) return @@ -44,32 +45,32 @@ func (stack *Stack) Pop() (value interface{}, ok bool) { // Peek returns top element on the stack without removing it, or nil if stack is empty. // Second return parameter is true, unless the stack was empty and there was nothing to peek. -func (stack *Stack) Peek() (value interface{}, ok bool) { +func (stack *Stack[T]) Peek() (value T, ok bool) { return stack.list.Get(0) } // Empty returns true if stack does not contain any elements. -func (stack *Stack) Empty() bool { +func (stack *Stack[T]) Empty() bool { return stack.list.Empty() } // Size returns number of elements within the stack. -func (stack *Stack) Size() int { +func (stack *Stack[T]) Size() int { return stack.list.Size() } // Clear removes all elements from the stack. -func (stack *Stack) Clear() { +func (stack *Stack[T]) Clear() { stack.list.Clear() } // Values returns all elements in the stack (LIFO order). -func (stack *Stack) Values() []interface{} { +func (stack *Stack[T]) Values() []T { return stack.list.Values() } // String returns a string representation of container -func (stack *Stack) String() string { +func (stack *Stack[T]) String() string { str := "LinkedListStack\n" values := []string{} for _, value := range stack.list.Values() { @@ -80,6 +81,6 @@ func (stack *Stack) String() string { } // Check that the index is within bounds of the list -func (stack *Stack) withinRange(index int) bool { +func (stack *Stack[T]) withinRange(index int) bool { return index >= 0 && index < stack.list.Size() } diff --git a/stacks/linkedliststack/linkedliststack_test.go b/stacks/linkedliststack/linkedliststack_test.go index f491fd3a..f215ba37 100644 --- a/stacks/linkedliststack/linkedliststack_test.go +++ b/stacks/linkedliststack/linkedliststack_test.go @@ -6,13 +6,14 @@ package linkedliststack import ( "encoding/json" - "fmt" "strings" "testing" + + "github.com/emirpasic/gods/v2/testutils" ) func TestStackPush(t *testing.T) { - stack := New() + stack := New[int]() if actualValue := stack.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) } @@ -20,7 +21,7 @@ func TestStackPush(t *testing.T) { stack.Push(2) stack.Push(3) - if actualValue := stack.Values(); actualValue[0].(int) != 3 || actualValue[1].(int) != 2 || actualValue[2].(int) != 1 { + if actualValue := stack.Values(); actualValue[0] != 3 || actualValue[1] != 2 || actualValue[2] != 1 { t.Errorf("Got %v expected %v", actualValue, "[3,2,1]") } if actualValue := stack.Empty(); actualValue != false { @@ -35,8 +36,8 @@ func TestStackPush(t *testing.T) { } func TestStackPeek(t *testing.T) { - stack := New() - if actualValue, ok := stack.Peek(); actualValue != nil || ok { + stack := New[int]() + if actualValue, ok := stack.Peek(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } stack.Push(1) @@ -48,7 +49,7 @@ func TestStackPeek(t *testing.T) { } func TestStackPop(t *testing.T) { - stack := New() + stack := New[int]() stack.Push(1) stack.Push(2) stack.Push(3) @@ -62,7 +63,7 @@ func TestStackPop(t *testing.T) { if actualValue, ok := stack.Pop(); actualValue != 1 || !ok { t.Errorf("Got %v expected %v", actualValue, 1) } - if actualValue, ok := stack.Pop(); actualValue != nil || ok { + if actualValue, ok := stack.Pop(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := stack.Empty(); actualValue != true { @@ -74,7 +75,7 @@ func TestStackPop(t *testing.T) { } func TestStackIterator(t *testing.T) { - stack := New() + stack := New[string]() stack.Push("a") stack.Push("b") stack.Push("c") @@ -118,7 +119,7 @@ func TestStackIterator(t *testing.T) { } func TestStackIteratorBegin(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() it.Begin() stack.Push("a") @@ -134,7 +135,7 @@ func TestStackIteratorBegin(t *testing.T) { } func TestStackIteratorFirst(t *testing.T) { - stack := New() + stack := New[string]() it := stack.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -152,13 +153,13 @@ func TestStackIteratorFirst(t *testing.T) { func TestStackIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - stack := New() + stack := New[string]() it := stack.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty stack") @@ -167,7 +168,7 @@ func TestStackIteratorNextTo(t *testing.T) { // NextTo (not found) { - stack := New() + stack := New[string]() stack.Push("xx") stack.Push("yy") it := stack.Iterator() @@ -178,7 +179,7 @@ func TestStackIteratorNextTo(t *testing.T) { // NextTo (found) { - stack := New() + stack := New[string]() stack.Push("aa") stack.Push("bb") stack.Push("cc") @@ -187,13 +188,13 @@ func TestStackIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty stack") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 2 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "aa") } if it.Next() { @@ -203,16 +204,14 @@ func TestStackIteratorNextTo(t *testing.T) { } func TestStackSerialization(t *testing.T) { - stack := New() + stack := New[string]() stack.Push("a") stack.Push("b") stack.Push("c") var err error assert := func() { - if actualValue, expectedValue := fmt.Sprintf("%s%s%s", stack.Values()...), "cba"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } + testutils.SameElements(t, stack.Values(), []string{"c", "b", "a"}) if actualValue, expectedValue := stack.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -234,21 +233,22 @@ func TestStackSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &stack) + err = json.Unmarshal([]byte(`["a","b","c"]`), &stack) if err != nil { t.Errorf("Got error %v", err) } + assert() } func TestStackString(t *testing.T) { - c := New() + c := New[int]() c.Push(1) if !strings.HasPrefix(c.String(), "LinkedListStack") { t.Errorf("String should start with container name") } } -func benchmarkPush(b *testing.B, stack *Stack, size int) { +func benchmarkPush(b *testing.B, stack *Stack[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { stack.Push(n) @@ -256,7 +256,7 @@ func benchmarkPush(b *testing.B, stack *Stack, size int) { } } -func benchmarkPop(b *testing.B, stack *Stack, size int) { +func benchmarkPop(b *testing.B, stack *Stack[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { stack.Pop() @@ -267,7 +267,7 @@ func benchmarkPop(b *testing.B, stack *Stack, size int) { func BenchmarkLinkedListStackPop100(b *testing.B) { b.StopTimer() size := 100 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -278,7 +278,7 @@ func BenchmarkLinkedListStackPop100(b *testing.B) { func BenchmarkLinkedListStackPop1000(b *testing.B) { b.StopTimer() size := 1000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -289,7 +289,7 @@ func BenchmarkLinkedListStackPop1000(b *testing.B) { func BenchmarkLinkedListStackPop10000(b *testing.B) { b.StopTimer() size := 10000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -300,7 +300,7 @@ func BenchmarkLinkedListStackPop10000(b *testing.B) { func BenchmarkLinkedListStackPop100000(b *testing.B) { b.StopTimer() size := 100000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -311,7 +311,7 @@ func BenchmarkLinkedListStackPop100000(b *testing.B) { func BenchmarkLinkedListStackPush100(b *testing.B) { b.StopTimer() size := 100 - stack := New() + stack := New[int]() b.StartTimer() benchmarkPush(b, stack, size) } @@ -319,7 +319,7 @@ func BenchmarkLinkedListStackPush100(b *testing.B) { func BenchmarkLinkedListStackPush1000(b *testing.B) { b.StopTimer() size := 1000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -330,7 +330,7 @@ func BenchmarkLinkedListStackPush1000(b *testing.B) { func BenchmarkLinkedListStackPush10000(b *testing.B) { b.StopTimer() size := 10000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } @@ -341,7 +341,7 @@ func BenchmarkLinkedListStackPush10000(b *testing.B) { func BenchmarkLinkedListStackPush100000(b *testing.B) { b.StopTimer() size := 100000 - stack := New() + stack := New[int]() for n := 0; n < size; n++ { stack.Push(n) } diff --git a/stacks/linkedliststack/serialization.go b/stacks/linkedliststack/serialization.go index a82b768c..3a8ca044 100644 --- a/stacks/linkedliststack/serialization.go +++ b/stacks/linkedliststack/serialization.go @@ -5,29 +5,29 @@ package linkedliststack import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Stack)(nil) -var _ containers.JSONDeserializer = (*Stack)(nil) +var _ containers.JSONSerializer = (*Stack[int])(nil) +var _ containers.JSONDeserializer = (*Stack[int])(nil) // ToJSON outputs the JSON representation of the stack. -func (stack *Stack) ToJSON() ([]byte, error) { +func (stack *Stack[T]) ToJSON() ([]byte, error) { return stack.list.ToJSON() } // FromJSON populates the stack from the input JSON representation. -func (stack *Stack) FromJSON(data []byte) error { +func (stack *Stack[T]) FromJSON(data []byte) error { return stack.list.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (stack *Stack) UnmarshalJSON(bytes []byte) error { +func (stack *Stack[T]) UnmarshalJSON(bytes []byte) error { return stack.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (stack *Stack) MarshalJSON() ([]byte, error) { +func (stack *Stack[T]) MarshalJSON() ([]byte, error) { return stack.ToJSON() } diff --git a/stacks/stacks.go b/stacks/stacks.go index e9ae56e1..4acd0116 100644 --- a/stacks/stacks.go +++ b/stacks/stacks.go @@ -9,15 +9,15 @@ // Reference: https://en.wikipedia.org/wiki/Stack_%28abstract_data_type%29 package stacks -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Stack interface that all stacks implement -type Stack interface { - Push(value interface{}) - Pop() (value interface{}, ok bool) - Peek() (value interface{}, ok bool) +type Stack[T any] interface { + Push(value T) + Pop() (value T, ok bool) + Peek() (value T, ok bool) - containers.Container + containers.Container[T] // Empty() bool // Size() int // Clear() diff --git a/testutils/testutils.go b/testutils/testutils.go new file mode 100644 index 00000000..b97d6385 --- /dev/null +++ b/testutils/testutils.go @@ -0,0 +1,18 @@ +package testutils + +import "testing" + +func SameElements[T comparable](t *testing.T, actual, expected []T) { + if len(actual) != len(expected) { + t.Errorf("Got %d expected %d", len(actual), len(expected)) + } +outer: + for _, e := range expected { + for _, a := range actual { + if e == a { + continue outer + } + } + t.Errorf("Did not find expected element %v in %v", e, actual) + } +} diff --git a/trees/avltree/avltree.go b/trees/avltree/avltree.go index ec2765cd..81b9f3f8 100644 --- a/trees/avltree/avltree.go +++ b/trees/avltree/avltree.go @@ -10,68 +10,65 @@ package avltree import ( + "cmp" "fmt" - "github.com/emirpasic/gods/trees" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/trees" + "github.com/emirpasic/gods/v2/utils" ) // Assert Tree implementation -var _ trees.Tree = new(Tree) +var _ trees.Tree[int] = (*Tree[string, int])(nil) // Tree holds elements of the AVL tree. -type Tree struct { - Root *Node // Root node - Comparator utils.Comparator // Key comparator - size int // Total number of keys in the tree +type Tree[K comparable, V any] struct { + Root *Node[K, V] // Root node + Comparator utils.Comparator[K] // Key comparator + size int // Total number of keys in the tree } // Node is a single element within the tree -type Node struct { - Key interface{} - Value interface{} - Parent *Node // Parent node - Children [2]*Node // Children nodes +type Node[K comparable, V any] struct { + Key K + Value V + Parent *Node[K, V] // Parent node + Children [2]*Node[K, V] // Children nodes b int8 } -// NewWith instantiates an AVL tree with the custom comparator. -func NewWith(comparator utils.Comparator) *Tree { - return &Tree{Comparator: comparator} -} - -// NewWithIntComparator instantiates an AVL tree with the IntComparator, i.e. keys are of type int. -func NewWithIntComparator() *Tree { - return &Tree{Comparator: utils.IntComparator} +// New instantiates an AVL tree with the built-in comparator for K +func New[K cmp.Ordered, V any]() *Tree[K, V] { + return &Tree[K, V]{Comparator: cmp.Compare[K]} } -// NewWithStringComparator instantiates an AVL tree with the StringComparator, i.e. keys are of type string. -func NewWithStringComparator() *Tree { - return &Tree{Comparator: utils.StringComparator} +// NewWith instantiates an AVL tree with the custom comparator. +func NewWith[K comparable, V any](comparator utils.Comparator[K]) *Tree[K, V] { + return &Tree[K, V]{Comparator: comparator} } // Put inserts node into the tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Put(key interface{}, value interface{}) { - t.put(key, value, nil, &t.Root) +func (tree *Tree[K, V]) Put(key K, value V) { + tree.put(key, value, nil, &tree.Root) } // Get searches the node in the tree by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Get(key interface{}) (value interface{}, found bool) { - n := t.GetNode(key) +func (tree *Tree[K, V]) Get(key K) (value V, found bool) { + n := tree.GetNode(key) if n != nil { return n.Value, true } - return nil, false + return value, false } // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) GetNode(key interface{}) *Node { - n := t.Root +func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] { + n := tree.Root for n != nil { - cmp := t.Comparator(key, n.Key) + cmp := tree.Comparator(key, n.Key) switch { case cmp == 0: return n @@ -86,23 +83,23 @@ func (t *Tree) GetNode(key interface{}) *Node { // Remove remove the node from the tree by key. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Remove(key interface{}) { - t.remove(key, &t.Root) +func (tree *Tree[K, V]) Remove(key K) { + tree.remove(key, &tree.Root) } // Empty returns true if tree does not contain any nodes. -func (t *Tree) Empty() bool { - return t.size == 0 +func (tree *Tree[K, V]) Empty() bool { + return tree.size == 0 } // Size returns the number of elements stored in the tree. -func (t *Tree) Size() int { - return t.size +func (tree *Tree[K, V]) Size() int { + return tree.size } // Size returns the number of elements stored in the subtree. // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes. -func (n *Node) Size() int { +func (n *Node[K, V]) Size() int { if n == nil { return 0 } @@ -117,9 +114,9 @@ func (n *Node) Size() int { } // Keys returns all keys in-order -func (t *Tree) Keys() []interface{} { - keys := make([]interface{}, t.size) - it := t.Iterator() +func (tree *Tree[K, V]) Keys() []K { + keys := make([]K, tree.size) + it := tree.Iterator() for i := 0; it.Next(); i++ { keys[i] = it.Key() } @@ -127,9 +124,9 @@ func (t *Tree) Keys() []interface{} { } // Values returns all values in-order based on the key. -func (t *Tree) Values() []interface{} { - values := make([]interface{}, t.size) - it := t.Iterator() +func (tree *Tree[K, V]) Values() []V { + values := make([]V, tree.size) + it := tree.Iterator() for i := 0; it.Next(); i++ { values[i] = it.Value() } @@ -138,14 +135,14 @@ func (t *Tree) Values() []interface{} { // Left returns the minimum element of the AVL tree // or nil if the tree is empty. -func (t *Tree) Left() *Node { - return t.bottom(0) +func (tree *Tree[K, V]) Left() *Node[K, V] { + return tree.bottom(0) } // Right returns the maximum element of the AVL tree // or nil if the tree is empty. -func (t *Tree) Right() *Node { - return t.bottom(1) +func (tree *Tree[K, V]) Right() *Node[K, V] { + return tree.bottom(1) } // Floor Finds floor node of the input key, return the floor node or nil if no floor is found. @@ -156,11 +153,11 @@ func (t *Tree) Right() *Node { // all nodes in the tree is larger than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Floor(key interface{}) (floor *Node, found bool) { +func (tree *Tree[K, V]) Floor(key K) (floor *Node[K, V], found bool) { found = false - n := t.Root + n := tree.Root for n != nil { - c := t.Comparator(key, n.Key) + c := tree.Comparator(key, n.Key) switch { case c == 0: return n, true @@ -185,11 +182,11 @@ func (t *Tree) Floor(key interface{}) (floor *Node, found bool) { // all nodes in the tree is smaller than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Ceiling(key interface{}) (floor *Node, found bool) { +func (tree *Tree[K, V]) Ceiling(key K) (floor *Node[K, V], found bool) { found = false - n := t.Root + n := tree.Root for n != nil { - c := t.Comparator(key, n.Key) + c := tree.Comparator(key, n.Key) switch { case c == 0: return n, true @@ -207,33 +204,33 @@ func (t *Tree) Ceiling(key interface{}) (floor *Node, found bool) { } // Clear removes all nodes from the tree. -func (t *Tree) Clear() { - t.Root = nil - t.size = 0 +func (tree *Tree[K, V]) Clear() { + tree.Root = nil + tree.size = 0 } // String returns a string representation of container -func (t *Tree) String() string { +func (tree *Tree[K, V]) String() string { str := "AVLTree\n" - if !t.Empty() { - output(t.Root, "", true, &str) + if !tree.Empty() { + output(tree.Root, "", true, &str) } return str } -func (n *Node) String() string { +func (n *Node[K, V]) String() string { return fmt.Sprintf("%v", n.Key) } -func (t *Tree) put(key interface{}, value interface{}, p *Node, qp **Node) bool { +func (tree *Tree[K, V]) put(key K, value V, p *Node[K, V], qp **Node[K, V]) bool { q := *qp if q == nil { - t.size++ - *qp = &Node{Key: key, Value: value, Parent: p} + tree.size++ + *qp = &Node[K, V]{Key: key, Value: value, Parent: p} return true } - c := t.Comparator(key, q.Key) + c := tree.Comparator(key, q.Key) if c == 0 { q.Key = key q.Value = value @@ -247,22 +244,22 @@ func (t *Tree) put(key interface{}, value interface{}, p *Node, qp **Node) bool } a := (c + 1) / 2 var fix bool - fix = t.put(key, value, q, &q.Children[a]) + fix = tree.put(key, value, q, &q.Children[a]) if fix { return putFix(int8(c), qp) } return false } -func (t *Tree) remove(key interface{}, qp **Node) bool { +func (tree *Tree[K, V]) remove(key K, qp **Node[K, V]) bool { q := *qp if q == nil { return false } - c := t.Comparator(key, q.Key) + c := tree.Comparator(key, q.Key) if c == 0 { - t.size-- + tree.size-- if q.Children[1] == nil { if q.Children[0] != nil { q.Children[0].Parent = q.Parent @@ -283,14 +280,14 @@ func (t *Tree) remove(key interface{}, qp **Node) bool { c = 1 } a := (c + 1) / 2 - fix := t.remove(key, &q.Children[a]) + fix := tree.remove(key, &q.Children[a]) if fix { return removeFix(int8(-c), qp) } return false } -func removeMin(qp **Node, minKey *interface{}, minVal *interface{}) bool { +func removeMin[K comparable, V any](qp **Node[K, V], minKey *K, minVal *V) bool { q := *qp if q.Children[0] == nil { *minKey = q.Key @@ -308,7 +305,7 @@ func removeMin(qp **Node, minKey *interface{}, minVal *interface{}) bool { return false } -func putFix(c int8, t **Node) bool { +func putFix[K comparable, V any](c int8, t **Node[K, V]) bool { s := *t if s.b == 0 { s.b = c @@ -329,7 +326,7 @@ func putFix(c int8, t **Node) bool { return false } -func removeFix(c int8, t **Node) bool { +func removeFix[K comparable, V any](c int8, t **Node[K, V]) bool { s := *t if s.b == 0 { s.b = c @@ -358,14 +355,14 @@ func removeFix(c int8, t **Node) bool { return true } -func singlerot(c int8, s *Node) *Node { +func singlerot[K comparable, V any](c int8, s *Node[K, V]) *Node[K, V] { s.b = 0 s = rotate(c, s) s.b = 0 return s } -func doublerot(c int8, s *Node) *Node { +func doublerot[K comparable, V any](c int8, s *Node[K, V]) *Node[K, V] { a := (c + 1) / 2 r := s.Children[a] s.Children[a] = rotate(-c, s.Children[a]) @@ -387,7 +384,7 @@ func doublerot(c int8, s *Node) *Node { return p } -func rotate(c int8, s *Node) *Node { +func rotate[K comparable, V any](c int8, s *Node[K, V]) *Node[K, V] { a := (c + 1) / 2 r := s.Children[a] s.Children[a] = r.Children[a^1] @@ -400,8 +397,8 @@ func rotate(c int8, s *Node) *Node { return r } -func (t *Tree) bottom(d int) *Node { - n := t.Root +func (tree *Tree[K, V]) bottom(d int) *Node[K, V] { + n := tree.Root if n == nil { return nil } @@ -414,17 +411,17 @@ func (t *Tree) bottom(d int) *Node { // Prev returns the previous element in an inorder // walk of the AVL tree. -func (n *Node) Prev() *Node { +func (n *Node[K, V]) Prev() *Node[K, V] { return n.walk1(0) } // Next returns the next element in an inorder // walk of the AVL tree. -func (n *Node) Next() *Node { +func (n *Node[K, V]) Next() *Node[K, V] { return n.walk1(1) } -func (n *Node) walk1(a int) *Node { +func (n *Node[K, V]) walk1(a int) *Node[K, V] { if n == nil { return nil } @@ -445,7 +442,7 @@ func (n *Node) walk1(a int) *Node { return p } -func output(node *Node, prefix string, isTail bool, str *string) { +func output[K comparable, V any](node *Node[K, V], prefix string, isTail bool, str *string) { if node.Children[1] != nil { newPrefix := prefix if isTail { diff --git a/trees/avltree/avltree_test.go b/trees/avltree/avltree_test.go index 114b5a5e..346fc479 100644 --- a/trees/avltree/avltree_test.go +++ b/trees/avltree/avltree_test.go @@ -5,14 +5,13 @@ package avltree import ( "encoding/json" - "fmt" - "github.com/emirpasic/gods/utils" + "slices" "strings" "testing" ) func TestAVLTreeGet(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() if actualValue := tree.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) @@ -56,7 +55,7 @@ func TestAVLTreeGet(t *testing.T) { } func TestAVLTreePut(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -69,10 +68,10 @@ func TestAVLTreePut(t *testing.T) { if actualValue := tree.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := fmt.Sprintf("%d%d%d%d%d%d%d", tree.Keys()...), "1234567"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{1, 2, 3, 4, 5, 6, 7}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s%s%s%s", tree.Values()...), "abcdefg"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -84,12 +83,12 @@ func TestAVLTreePut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := tree.Get(test[0]) + actualValue, actualFound := tree.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -97,7 +96,7 @@ func TestAVLTreePut(t *testing.T) { } func TestAVLTreeRemove(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -113,13 +112,10 @@ func TestAVLTreeRemove(t *testing.T) { tree.Remove(8) tree.Remove(5) - if actualValue, expectedValue := fmt.Sprintf("%d%d%d%d", tree.Keys()...), "1234"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{1, 2, 3, 4}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", tree.Values()...), "abcd"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", tree.Values()...), "abcd"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{"a", "b", "c", "d"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue := tree.Size(); actualValue != 4 { @@ -131,14 +127,14 @@ func TestAVLTreeRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := tree.Get(test[0]) + actualValue, actualFound := tree.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -151,10 +147,10 @@ func TestAVLTreeRemove(t *testing.T) { tree.Remove(2) tree.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Keys()), "[]"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Values()), "[]"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if empty, size := tree.Empty(), tree.Size(); empty != true || size != -0 { @@ -164,7 +160,7 @@ func TestAVLTreeRemove(t *testing.T) { } func TestAVLTreeLeftAndRight(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() if actualValue := tree.Left(); actualValue != nil { t.Errorf("Got %v expected %v", actualValue, nil) @@ -182,23 +178,23 @@ func TestAVLTreeLeftAndRight(t *testing.T) { tree.Put(1, "x") // overwrite tree.Put(2, "b") - if actualValue, expectedValue := fmt.Sprintf("%d", tree.Left().Key), "1"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Left().Key, 1; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Left().Value), "x"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Left().Value, "x"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%d", tree.Right().Key), "7"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Right().Key, 7; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Right().Value), "g"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Right().Value, "g"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestAVLTreeCeilingAndFloor(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() if node, found := tree.Floor(0); node != nil || found { t.Errorf("Got %v expected %v", node, "") @@ -231,7 +227,7 @@ func TestAVLTreeCeilingAndFloor(t *testing.T) { } func TestAVLTreeIteratorNextOnEmpty(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty tree") @@ -239,7 +235,7 @@ func TestAVLTreeIteratorNextOnEmpty(t *testing.T) { } func TestAVLTreeIteratorPrevOnEmpty(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty tree") @@ -247,7 +243,7 @@ func TestAVLTreeIteratorPrevOnEmpty(t *testing.T) { } func TestAVLTreeIterator1Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -279,7 +275,7 @@ func TestAVLTreeIterator1Next(t *testing.T) { } func TestAVLTreeIterator1Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -312,7 +308,7 @@ func TestAVLTreeIterator1Prev(t *testing.T) { } func TestAVLTreeIterator2Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -331,7 +327,7 @@ func TestAVLTreeIterator2Next(t *testing.T) { } func TestAVLTreeIterator2Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -352,7 +348,7 @@ func TestAVLTreeIterator2Prev(t *testing.T) { } func TestAVLTreeIterator3Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(1, "a") it := tree.Iterator() count := 0 @@ -369,7 +365,7 @@ func TestAVLTreeIterator3Next(t *testing.T) { } func TestAVLTreeIterator3Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(1, "a") it := tree.Iterator() for it.Next() { @@ -388,7 +384,7 @@ func TestAVLTreeIterator3Prev(t *testing.T) { } func TestAVLTreeIterator4Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, int]() tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -424,7 +420,7 @@ func TestAVLTreeIterator4Next(t *testing.T) { } func TestAVLTreeIterator4Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, int]() tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -462,20 +458,20 @@ func TestAVLTreeIterator4Prev(t *testing.T) { } func TestAVLTreeIteratorBegin(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") it := tree.Iterator() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } it.Begin() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } for it.Next() { @@ -483,8 +479,8 @@ func TestAVLTreeIteratorBegin(t *testing.T) { it.Begin() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } it.Next() @@ -494,24 +490,24 @@ func TestAVLTreeIteratorBegin(t *testing.T) { } func TestAVLTreeIteratorEnd(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } it.End() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") it.End() - if it.Key() != nil { - t.Errorf("Got %v expected %v", it.Key(), nil) + if it.Key() != 0 { + t.Errorf("Got %v expected %v", it.Key(), 0) } it.Prev() @@ -521,7 +517,7 @@ func TestAVLTreeIteratorEnd(t *testing.T) { } func TestAVLTreeIteratorFirst(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -535,7 +531,7 @@ func TestAVLTreeIteratorFirst(t *testing.T) { } func TestAVLTreeIteratorLast(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -550,13 +546,13 @@ func TestAVLTreeIteratorLast(t *testing.T) { func TestAVLTreeIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") @@ -565,7 +561,7 @@ func TestAVLTreeIteratorNextTo(t *testing.T) { // NextTo (not found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -576,7 +572,7 @@ func TestAVLTreeIteratorNextTo(t *testing.T) { // NextTo (found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -585,13 +581,13 @@ func TestAVLTreeIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -602,13 +598,13 @@ func TestAVLTreeIteratorNextTo(t *testing.T) { func TestAVLTreeIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() it.End() for it.PrevTo(seek) { @@ -618,7 +614,7 @@ func TestAVLTreeIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -630,7 +626,7 @@ func TestAVLTreeIteratorPrevTo(t *testing.T) { // PrevTo (found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -639,13 +635,13 @@ func TestAVLTreeIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -655,8 +651,7 @@ func TestAVLTreeIteratorPrevTo(t *testing.T) { } func TestAVLTreeSerialization(t *testing.T) { - tree := NewWith(utils.StringComparator) - tree = NewWithStringComparator() + tree := New[string, string]() tree.Put("c", "3") tree.Put("b", "2") tree.Put("a", "1") @@ -666,11 +661,11 @@ func TestAVLTreeSerialization(t *testing.T) { if actualValue, expectedValue := tree.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Keys(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { - t.Errorf("Got %v expected %v", actualValue, "[a,b,c]") + if actualValue, expectedValue := tree.Keys(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Values(); actualValue[0].(string) != "1" || actualValue[1].(string) != "2" || actualValue[2].(string) != "3" { - t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") + if actualValue, expectedValue := tree.Values(), []string{"1", "2", "3"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if err != nil { t.Errorf("Got error %v", err) @@ -690,14 +685,24 @@ func TestAVLTreeSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`{"a":1,"b":2}`), &tree) + intTree := New[string, int]() + err = json.Unmarshal([]byte(`{"a":1,"b":2}`), intTree) if err != nil { t.Errorf("Got error %v", err) } + if actualValue, expectedValue := intTree.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Keys(), []string{"a", "b"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Values(), []int{1, 2}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } } func TestAVLTreeString(t *testing.T) { - c := NewWithIntComparator() + c := New[int, int]() c.Put(1, 1) c.Put(2, 1) c.Put(3, 1) @@ -712,7 +717,7 @@ func TestAVLTreeString(t *testing.T) { } } -func benchmarkGet(b *testing.B, tree *Tree, size int) { +func benchmarkGet(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Get(n) @@ -720,7 +725,7 @@ func benchmarkGet(b *testing.B, tree *Tree, size int) { } } -func benchmarkPut(b *testing.B, tree *Tree, size int) { +func benchmarkPut(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Put(n, struct{}{}) @@ -728,7 +733,7 @@ func benchmarkPut(b *testing.B, tree *Tree, size int) { } } -func benchmarkRemove(b *testing.B, tree *Tree, size int) { +func benchmarkRemove(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Remove(n) @@ -739,7 +744,7 @@ func benchmarkRemove(b *testing.B, tree *Tree, size int) { func BenchmarkAVLTreeGet100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -750,7 +755,7 @@ func BenchmarkAVLTreeGet100(b *testing.B) { func BenchmarkAVLTreeGet1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -761,7 +766,7 @@ func BenchmarkAVLTreeGet1000(b *testing.B) { func BenchmarkAVLTreeGet10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -772,7 +777,7 @@ func BenchmarkAVLTreeGet10000(b *testing.B) { func BenchmarkAVLTreeGet100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -783,7 +788,7 @@ func BenchmarkAVLTreeGet100000(b *testing.B) { func BenchmarkAVLTreePut100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() b.StartTimer() benchmarkPut(b, tree, size) } @@ -791,7 +796,7 @@ func BenchmarkAVLTreePut100(b *testing.B) { func BenchmarkAVLTreePut1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -802,7 +807,7 @@ func BenchmarkAVLTreePut1000(b *testing.B) { func BenchmarkAVLTreePut10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -813,7 +818,7 @@ func BenchmarkAVLTreePut10000(b *testing.B) { func BenchmarkAVLTreePut100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -824,7 +829,7 @@ func BenchmarkAVLTreePut100000(b *testing.B) { func BenchmarkAVLTreeRemove100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -835,7 +840,7 @@ func BenchmarkAVLTreeRemove100(b *testing.B) { func BenchmarkAVLTreeRemove1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -846,7 +851,7 @@ func BenchmarkAVLTreeRemove1000(b *testing.B) { func BenchmarkAVLTreeRemove10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -857,7 +862,7 @@ func BenchmarkAVLTreeRemove10000(b *testing.B) { func BenchmarkAVLTreeRemove100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } diff --git a/trees/avltree/iterator.go b/trees/avltree/iterator.go index 0186e40b..8541ce0a 100644 --- a/trees/avltree/iterator.go +++ b/trees/avltree/iterator.go @@ -4,15 +4,15 @@ package avltree -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - tree *Tree - node *Node +type Iterator[K comparable, V any] struct { + tree *Tree[K, V] + node *Node[K, V] position position } @@ -23,15 +23,15 @@ const ( ) // Iterator returns a stateful iterator whose elements are key/value pairs. -func (tree *Tree) Iterator() containers.ReverseIteratorWithKey { - return &Iterator{tree: tree, node: nil, position: begin} +func (tree *Tree[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{tree: tree, node: nil, position: begin} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { switch iterator.position { case begin: iterator.position = between @@ -51,7 +51,7 @@ func (iterator *Iterator) Next() bool { // If Prev() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Prev() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { switch iterator.position { case end: iterator.position = between @@ -69,38 +69,38 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[K, V]) Value() (v V) { if iterator.node == nil { - return nil + return v } return iterator.node.Value } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() (k K) { if iterator.node == nil { - return nil + return k } return iterator.node.Key } // Node returns the current element's node. // Does not modify the state of the iterator. -func (iterator *Iterator) Node() *Node { +func (iterator *Iterator[K, V]) Node() *Node[K, V] { return iterator.node } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.node = nil iterator.position = begin } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.node = nil iterator.position = end } @@ -108,7 +108,7 @@ func (iterator *Iterator) End() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { iterator.Begin() return iterator.Next() } @@ -116,7 +116,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { iterator.End() return iterator.Prev() } @@ -125,7 +125,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -139,7 +139,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/trees/avltree/serialization.go b/trees/avltree/serialization.go index 257c4040..4b3da5d3 100644 --- a/trees/avltree/serialization.go +++ b/trees/avltree/serialization.go @@ -6,43 +6,46 @@ package avltree import ( "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Tree)(nil) -var _ containers.JSONDeserializer = (*Tree)(nil) +var _ containers.JSONSerializer = (*Tree[string, int])(nil) +var _ containers.JSONDeserializer = (*Tree[string, int])(nil) // ToJSON outputs the JSON representation of the tree. -func (tree *Tree) ToJSON() ([]byte, error) { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) ToJSON() ([]byte, error) { + elements := make(map[K]V) it := tree.Iterator() for it.Next() { - elements[utils.ToString(it.Key())] = it.Value() + elements[it.Key()] = it.Value() } return json.Marshal(&elements) } // FromJSON populates the tree from the input JSON representation. -func (tree *Tree) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) FromJSON(data []byte) error { + elements := make(map[K]V) err := json.Unmarshal(data, &elements) - if err == nil { - tree.Clear() - for key, value := range elements { - tree.Put(key, value) - } + if err != nil { + return err + } + + tree.Clear() + for key, value := range elements { + tree.Put(key, value) } - return err + + return nil } // UnmarshalJSON @implements json.Unmarshaler -func (tree *Tree) UnmarshalJSON(bytes []byte) error { +func (tree *Tree[K, V]) UnmarshalJSON(bytes []byte) error { return tree.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (tree *Tree) MarshalJSON() ([]byte, error) { +func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) { return tree.ToJSON() } diff --git a/trees/binaryheap/binaryheap.go b/trees/binaryheap/binaryheap.go index e658f257..9f1605cd 100644 --- a/trees/binaryheap/binaryheap.go +++ b/trees/binaryheap/binaryheap.go @@ -12,39 +12,36 @@ package binaryheap import ( + "cmp" "fmt" - "github.com/emirpasic/gods/lists/arraylist" - "github.com/emirpasic/gods/trees" - "github.com/emirpasic/gods/utils" "strings" + + "github.com/emirpasic/gods/v2/lists/arraylist" + "github.com/emirpasic/gods/v2/trees" + "github.com/emirpasic/gods/v2/utils" ) // Assert Tree implementation -var _ trees.Tree = (*Heap)(nil) +var _ trees.Tree[int] = (*Heap[int])(nil) // Heap holds elements in an array-list -type Heap struct { - list *arraylist.List - Comparator utils.Comparator +type Heap[T comparable] struct { + list *arraylist.List[T] + Comparator utils.Comparator[T] } -// NewWith instantiates a new empty heap tree with the custom comparator. -func NewWith(comparator utils.Comparator) *Heap { - return &Heap{list: arraylist.New(), Comparator: comparator} +// New instantiates a new empty heap tree with the built-in comparator for T +func New[T cmp.Ordered]() *Heap[T] { + return &Heap[T]{list: arraylist.New[T](), Comparator: cmp.Compare[T]} } -// NewWithIntComparator instantiates a new empty heap with the IntComparator, i.e. elements are of type int. -func NewWithIntComparator() *Heap { - return &Heap{list: arraylist.New(), Comparator: utils.IntComparator} -} - -// NewWithStringComparator instantiates a new empty heap with the StringComparator, i.e. elements are of type string. -func NewWithStringComparator() *Heap { - return &Heap{list: arraylist.New(), Comparator: utils.StringComparator} +// NewWith instantiates a new empty heap tree with the custom comparator. +func NewWith[T comparable](comparator utils.Comparator[T]) *Heap[T] { + return &Heap[T]{list: arraylist.New[T](), Comparator: comparator} } // Push adds a value onto the heap and bubbles it up accordingly. -func (heap *Heap) Push(values ...interface{}) { +func (heap *Heap[T]) Push(values ...T) { if len(values) == 1 { heap.list.Add(values[0]) heap.bubbleUp() @@ -62,7 +59,7 @@ func (heap *Heap) Push(values ...interface{}) { // Pop removes top element on heap and returns it, or nil if heap is empty. // Second return parameter is true, unless the heap was empty and there was nothing to pop. -func (heap *Heap) Pop() (value interface{}, ok bool) { +func (heap *Heap[T]) Pop() (value T, ok bool) { value, ok = heap.list.Get(0) if !ok { return @@ -76,28 +73,28 @@ func (heap *Heap) Pop() (value interface{}, ok bool) { // Peek returns top element on the heap without removing it, or nil if heap is empty. // Second return parameter is true, unless the heap was empty and there was nothing to peek. -func (heap *Heap) Peek() (value interface{}, ok bool) { +func (heap *Heap[T]) Peek() (value T, ok bool) { return heap.list.Get(0) } // Empty returns true if heap does not contain any elements. -func (heap *Heap) Empty() bool { +func (heap *Heap[T]) Empty() bool { return heap.list.Empty() } // Size returns number of elements within the heap. -func (heap *Heap) Size() int { +func (heap *Heap[T]) Size() int { return heap.list.Size() } // Clear removes all elements from the heap. -func (heap *Heap) Clear() { +func (heap *Heap[T]) Clear() { heap.list.Clear() } // Values returns all elements in the heap. -func (heap *Heap) Values() []interface{} { - values := make([]interface{}, heap.list.Size(), heap.list.Size()) +func (heap *Heap[T]) Values() []T { + values := make([]T, heap.list.Size(), heap.list.Size()) for it := heap.Iterator(); it.Next(); { values[it.Index()] = it.Value() } @@ -105,7 +102,7 @@ func (heap *Heap) Values() []interface{} { } // String returns a string representation of container -func (heap *Heap) String() string { +func (heap *Heap[T]) String() string { str := "BinaryHeap\n" values := []string{} for it := heap.Iterator(); it.Next(); { @@ -117,13 +114,13 @@ func (heap *Heap) String() string { // Performs the "bubble down" operation. This is to place the element that is at the root // of the heap in its correct place so that the heap maintains the min/max-heap order property. -func (heap *Heap) bubbleDown() { +func (heap *Heap[T]) bubbleDown() { heap.bubbleDownIndex(0) } // Performs the "bubble down" operation. This is to place the element that is at the index // of the heap in its correct place so that the heap maintains the min/max-heap order property. -func (heap *Heap) bubbleDownIndex(index int) { +func (heap *Heap[T]) bubbleDownIndex(index int) { size := heap.list.Size() for leftIndex := index<<1 + 1; leftIndex < size; leftIndex = index<<1 + 1 { rightIndex := index<<1 + 2 @@ -147,7 +144,7 @@ func (heap *Heap) bubbleDownIndex(index int) { // Performs the "bubble up" operation. This is to place a newly inserted // element (i.e. last element in the list) in its correct place so that // the heap maintains the min/max-heap order property. -func (heap *Heap) bubbleUp() { +func (heap *Heap[T]) bubbleUp() { index := heap.list.Size() - 1 for parentIndex := (index - 1) >> 1; index > 0; parentIndex = (index - 1) >> 1 { indexValue, _ := heap.list.Get(index) @@ -161,6 +158,6 @@ func (heap *Heap) bubbleUp() { } // Check that the index is within bounds of the list -func (heap *Heap) withinRange(index int) bool { +func (heap *Heap[T]) withinRange(index int) bool { return index >= 0 && index < heap.list.Size() } diff --git a/trees/binaryheap/binaryheap_test.go b/trees/binaryheap/binaryheap_test.go index bb5c42b1..5ee8df4e 100644 --- a/trees/binaryheap/binaryheap_test.go +++ b/trees/binaryheap/binaryheap_test.go @@ -7,12 +7,13 @@ package binaryheap import ( "encoding/json" "math/rand" + "slices" "strings" "testing" ) func TestBinaryHeapPush(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() if actualValue := heap.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) @@ -22,8 +23,8 @@ func TestBinaryHeapPush(t *testing.T) { heap.Push(2) heap.Push(1) - if actualValue := heap.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { - t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") + if actualValue, expectedValue := heap.Values(), []int{1, 2, 3}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue := heap.Empty(); actualValue != false { t.Errorf("Got %v expected %v", actualValue, false) @@ -37,12 +38,12 @@ func TestBinaryHeapPush(t *testing.T) { } func TestBinaryHeapPushBulk(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() heap.Push(15, 20, 3, 1, 2) - if actualValue := heap.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { - t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") + if actualValue, expectedValue := heap.Values(), []int{1, 2, 3, 15, 20}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue, ok := heap.Pop(); actualValue != 1 || !ok { t.Errorf("Got %v expected %v", actualValue, 1) @@ -50,7 +51,7 @@ func TestBinaryHeapPushBulk(t *testing.T) { } func TestBinaryHeapPop(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() if actualValue := heap.Empty(); actualValue != true { t.Errorf("Got %v expected %v", actualValue, true) @@ -70,7 +71,7 @@ func TestBinaryHeapPop(t *testing.T) { if actualValue, ok := heap.Pop(); actualValue != 3 || !ok { t.Errorf("Got %v expected %v", actualValue, 3) } - if actualValue, ok := heap.Pop(); actualValue != nil || ok { + if actualValue, ok := heap.Pop(); actualValue != 0 || ok { t.Errorf("Got %v expected %v", actualValue, nil) } if actualValue := heap.Empty(); actualValue != true { @@ -82,7 +83,7 @@ func TestBinaryHeapPop(t *testing.T) { } func TestBinaryHeapRandom(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() rand.Seed(3) for i := 0; i < 10000; i++ { @@ -93,7 +94,7 @@ func TestBinaryHeapRandom(t *testing.T) { prev, _ := heap.Pop() for !heap.Empty() { curr, _ := heap.Pop() - if prev.(int) > curr.(int) { + if prev > curr { t.Errorf("Heap property invalidated. prev: %v current: %v", prev, curr) } prev = curr @@ -101,7 +102,7 @@ func TestBinaryHeapRandom(t *testing.T) { } func TestBinaryHeapIteratorOnEmpty(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() it := heap.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty heap") @@ -109,7 +110,7 @@ func TestBinaryHeapIteratorOnEmpty(t *testing.T) { } func TestBinaryHeapIteratorNext(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() heap.Push(3) heap.Push(2) heap.Push(1) @@ -146,7 +147,7 @@ func TestBinaryHeapIteratorNext(t *testing.T) { } func TestBinaryHeapIteratorPrev(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() heap.Push(3) heap.Push(2) heap.Push(1) @@ -185,7 +186,7 @@ func TestBinaryHeapIteratorPrev(t *testing.T) { } func TestBinaryHeapIteratorBegin(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() it := heap.Iterator() it.Begin() heap.Push(2) @@ -201,7 +202,7 @@ func TestBinaryHeapIteratorBegin(t *testing.T) { } func TestBinaryHeapIteratorEnd(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() it := heap.Iterator() if index := it.Index(); index != -1 { @@ -228,7 +229,7 @@ func TestBinaryHeapIteratorEnd(t *testing.T) { } func TestBinaryHeapIteratorFirst(t *testing.T) { - heap := NewWithIntComparator() + heap := New[int]() it := heap.Iterator() if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -245,7 +246,7 @@ func TestBinaryHeapIteratorFirst(t *testing.T) { } func TestBinaryHeapIteratorLast(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int]() it := tree.Iterator() if actualValue, expectedValue := it.Last(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) @@ -263,13 +264,13 @@ func TestBinaryHeapIteratorLast(t *testing.T) { func TestBinaryHeapIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - tree := NewWithStringComparator() + tree := New[string]() it := tree.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") @@ -278,7 +279,7 @@ func TestBinaryHeapIteratorNextTo(t *testing.T) { // NextTo (not found) { - tree := NewWithStringComparator() + tree := New[string]() tree.Push("xx") tree.Push("yy") it := tree.Iterator() @@ -289,7 +290,7 @@ func TestBinaryHeapIteratorNextTo(t *testing.T) { // NextTo (found) { - tree := NewWithStringComparator() + tree := New[string]() tree.Push("aa") tree.Push("bb") tree.Push("cc") @@ -298,13 +299,13 @@ func TestBinaryHeapIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Index(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -315,13 +316,13 @@ func TestBinaryHeapIteratorNextTo(t *testing.T) { func TestBinaryHeapIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index int, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - tree := NewWithStringComparator() + tree := New[string]() it := tree.Iterator() it.End() for it.PrevTo(seek) { @@ -331,7 +332,7 @@ func TestBinaryHeapIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - tree := NewWithStringComparator() + tree := New[string]() tree.Push("xx") tree.Push("yy") it := tree.Iterator() @@ -343,7 +344,7 @@ func TestBinaryHeapIteratorPrevTo(t *testing.T) { // PrevTo (found) { - tree := NewWithStringComparator() + tree := New[string]() tree.Push("aa") tree.Push("bb") tree.Push("cc") @@ -352,13 +353,13 @@ func TestBinaryHeapIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty list") } - if index, value := it.Index(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Index(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Index(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Index(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -368,7 +369,7 @@ func TestBinaryHeapIteratorPrevTo(t *testing.T) { } func TestBinaryHeapSerialization(t *testing.T) { - heap := NewWithStringComparator() + heap := New[string]() heap.Push("c") heap.Push("b") @@ -376,8 +377,8 @@ func TestBinaryHeapSerialization(t *testing.T) { var err error assert := func() { - if actualValue := heap.Values(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { - t.Errorf("Got %v expected %v", actualValue, "[1,3,2]") + if actualValue, expectedValue := heap.Values(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue := heap.Size(); actualValue != 3 { t.Errorf("Got %v expected %v", actualValue, 3) @@ -403,21 +404,25 @@ func TestBinaryHeapSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`[1,2,3]`), &heap) + intHeap := New[int]() + err = json.Unmarshal([]byte(`[1,2,3]`), &intHeap) if err != nil { t.Errorf("Got error %v", err) } + if actualValue, expectedValue := intHeap.Values(), []int{1, 2, 3}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } } func TestBTreeString(t *testing.T) { - c := NewWithIntComparator() + c := New[int]() c.Push(1) if !strings.HasPrefix(c.String(), "BinaryHeap") { t.Errorf("String should start with container name") } } -func benchmarkPush(b *testing.B, heap *Heap, size int) { +func benchmarkPush(b *testing.B, heap *Heap[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { heap.Push(n) @@ -425,7 +430,7 @@ func benchmarkPush(b *testing.B, heap *Heap, size int) { } } -func benchmarkPop(b *testing.B, heap *Heap, size int) { +func benchmarkPop(b *testing.B, heap *Heap[int], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { heap.Pop() @@ -436,7 +441,7 @@ func benchmarkPop(b *testing.B, heap *Heap, size int) { func BenchmarkBinaryHeapPop100(b *testing.B) { b.StopTimer() size := 100 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -447,7 +452,7 @@ func BenchmarkBinaryHeapPop100(b *testing.B) { func BenchmarkBinaryHeapPop1000(b *testing.B) { b.StopTimer() size := 1000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -458,7 +463,7 @@ func BenchmarkBinaryHeapPop1000(b *testing.B) { func BenchmarkBinaryHeapPop10000(b *testing.B) { b.StopTimer() size := 10000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -469,7 +474,7 @@ func BenchmarkBinaryHeapPop10000(b *testing.B) { func BenchmarkBinaryHeapPop100000(b *testing.B) { b.StopTimer() size := 100000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -480,7 +485,7 @@ func BenchmarkBinaryHeapPop100000(b *testing.B) { func BenchmarkBinaryHeapPush100(b *testing.B) { b.StopTimer() size := 100 - heap := NewWithIntComparator() + heap := New[int]() b.StartTimer() benchmarkPush(b, heap, size) } @@ -488,7 +493,7 @@ func BenchmarkBinaryHeapPush100(b *testing.B) { func BenchmarkBinaryHeapPush1000(b *testing.B) { b.StopTimer() size := 1000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -499,7 +504,7 @@ func BenchmarkBinaryHeapPush1000(b *testing.B) { func BenchmarkBinaryHeapPush10000(b *testing.B) { b.StopTimer() size := 10000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } @@ -510,7 +515,7 @@ func BenchmarkBinaryHeapPush10000(b *testing.B) { func BenchmarkBinaryHeapPush100000(b *testing.B) { b.StopTimer() size := 100000 - heap := NewWithIntComparator() + heap := New[int]() for n := 0; n < size; n++ { heap.Push(n) } diff --git a/trees/binaryheap/iterator.go b/trees/binaryheap/iterator.go index f2179633..73ff1815 100644 --- a/trees/binaryheap/iterator.go +++ b/trees/binaryheap/iterator.go @@ -5,28 +5,28 @@ package binaryheap import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Iterator implementation -var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) +var _ containers.ReverseIteratorWithIndex[int] = (*Iterator[int])(nil) // Iterator returns a stateful iterator whose values can be fetched by an index. -type Iterator struct { - heap *Heap +type Iterator[T comparable] struct { + heap *Heap[T] index int } // Iterator returns a stateful iterator whose values can be fetched by an index. -func (heap *Heap) Iterator() Iterator { - return Iterator{heap: heap, index: -1} +func (heap *Heap[T]) Iterator() *Iterator[T] { + return &Iterator[T]{heap: heap, index: -1} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's index and value can be retrieved by Index() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[T]) Next() bool { if iterator.index < iterator.heap.Size() { iterator.index++ } @@ -36,7 +36,7 @@ func (iterator *Iterator) Next() bool { // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[T]) Prev() bool { if iterator.index >= 0 { iterator.index-- } @@ -45,7 +45,7 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[T]) Value() T { start, end := evaluateRange(iterator.index) if end > iterator.heap.Size() { end = iterator.heap.Size() @@ -64,26 +64,26 @@ func (iterator *Iterator) Value() interface{} { // Index returns the current element's index. // Does not modify the state of the iterator. -func (iterator *Iterator) Index() int { +func (iterator *Iterator[T]) Index() int { return iterator.index } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[T]) Begin() { iterator.index = -1 } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[T]) End() { iterator.index = iterator.heap.Size() } // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) First() bool { +func (iterator *Iterator[T]) First() bool { iterator.Begin() return iterator.Next() } @@ -91,7 +91,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[T]) Last() bool { iterator.End() return iterator.Prev() } @@ -100,7 +100,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) NextTo(f func(index int, value T) bool) bool { for iterator.Next() { index, value := iterator.Index(), iterator.Value() if f(index, value) { @@ -114,7 +114,7 @@ func (iterator *Iterator) NextTo(f func(index int, value interface{}) bool) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's index and value can be retrieved by Index() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool { +func (iterator *Iterator[T]) PrevTo(f func(index int, value T) bool) bool { for iterator.Prev() { index, value := iterator.Index(), iterator.Value() if f(index, value) { diff --git a/trees/binaryheap/serialization.go b/trees/binaryheap/serialization.go index c1fce0a0..5881b05b 100644 --- a/trees/binaryheap/serialization.go +++ b/trees/binaryheap/serialization.go @@ -5,29 +5,29 @@ package binaryheap import ( - "github.com/emirpasic/gods/containers" + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Heap)(nil) -var _ containers.JSONDeserializer = (*Heap)(nil) +var _ containers.JSONSerializer = (*Heap[int])(nil) +var _ containers.JSONDeserializer = (*Heap[int])(nil) // ToJSON outputs the JSON representation of the heap. -func (heap *Heap) ToJSON() ([]byte, error) { +func (heap *Heap[T]) ToJSON() ([]byte, error) { return heap.list.ToJSON() } // FromJSON populates the heap from the input JSON representation. -func (heap *Heap) FromJSON(data []byte) error { +func (heap *Heap[T]) FromJSON(data []byte) error { return heap.list.FromJSON(data) } // UnmarshalJSON @implements json.Unmarshaler -func (heap *Heap) UnmarshalJSON(bytes []byte) error { +func (heap *Heap[T]) UnmarshalJSON(bytes []byte) error { return heap.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (heap *Heap) MarshalJSON() ([]byte, error) { +func (heap *Heap[T]) MarshalJSON() ([]byte, error) { return heap.ToJSON() } diff --git a/trees/btree/btree.go b/trees/btree/btree.go index eae4576b..75f60ad6 100644 --- a/trees/btree/btree.go +++ b/trees/btree/btree.go @@ -18,62 +18,59 @@ package btree import ( "bytes" + "cmp" "fmt" - "github.com/emirpasic/gods/trees" - "github.com/emirpasic/gods/utils" "strings" + + "github.com/emirpasic/gods/v2/trees" + "github.com/emirpasic/gods/v2/utils" ) // Assert Tree implementation -var _ trees.Tree = (*Tree)(nil) +var _ trees.Tree[int] = (*Tree[string, int])(nil) // Tree holds elements of the B-tree -type Tree struct { - Root *Node // Root node - Comparator utils.Comparator // Key comparator - size int // Total number of keys in the tree - m int // order (maximum number of children) +type Tree[K comparable, V any] struct { + Root *Node[K, V] // Root node + Comparator utils.Comparator[K] // Key comparator + size int // Total number of keys in the tree + m int // order (maximum number of children) } // Node is a single element within the tree -type Node struct { - Parent *Node - Entries []*Entry // Contained keys in node - Children []*Node // Children nodes +type Node[K comparable, V any] struct { + Parent *Node[K, V] + Entries []*Entry[K, V] // Contained keys in node + Children []*Node[K, V] // Children nodes } // Entry represents the key-value pair contained within nodes -type Entry struct { - Key interface{} - Value interface{} +type Entry[K comparable, V any] struct { + Key K + Value V +} + +// New instantiates a B-tree with the order (maximum number of children) and the built-in comparator for K +func New[K cmp.Ordered, V any](order int) *Tree[K, V] { + return NewWith[K, V](order, cmp.Compare[K]) } // NewWith instantiates a B-tree with the order (maximum number of children) and a custom key comparator. -func NewWith(order int, comparator utils.Comparator) *Tree { +func NewWith[K comparable, V any](order int, comparator utils.Comparator[K]) *Tree[K, V] { if order < 3 { panic("Invalid order, should be at least 3") } - return &Tree{m: order, Comparator: comparator} -} - -// NewWithIntComparator instantiates a B-tree with the order (maximum number of children) and the IntComparator, i.e. keys are of type int. -func NewWithIntComparator(order int) *Tree { - return NewWith(order, utils.IntComparator) -} - -// NewWithStringComparator instantiates a B-tree with the order (maximum number of children) and the StringComparator, i.e. keys are of type string. -func NewWithStringComparator(order int) *Tree { - return NewWith(order, utils.StringComparator) + return &Tree[K, V]{m: order, Comparator: comparator} } // Put inserts key-value pair node into the tree. // If key already exists, then its value is updated with the new value. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Put(key interface{}, value interface{}) { - entry := &Entry{Key: key, Value: value} +func (tree *Tree[K, V]) Put(key K, value V) { + entry := &Entry[K, V]{Key: key, Value: value} if tree.Root == nil { - tree.Root = &Node{Entries: []*Entry{entry}, Children: []*Node{}} + tree.Root = &Node[K, V]{Entries: []*Entry[K, V]{entry}, Children: []*Node[K, V]{}} tree.size++ return } @@ -86,24 +83,24 @@ func (tree *Tree) Put(key interface{}, value interface{}) { // Get searches the node in the tree by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Get(key interface{}) (value interface{}, found bool) { +func (tree *Tree[K, V]) Get(key K) (value V, found bool) { node, index, found := tree.searchRecursively(tree.Root, key) if found { return node.Entries[index].Value, true } - return nil, false + return value, false } // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) GetNode(key interface{}) *Node { +func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] { node, _, _ := tree.searchRecursively(tree.Root, key) return node } // Remove remove the node from the tree by key. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Remove(key interface{}) { +func (tree *Tree[K, V]) Remove(key K) { node, index, found := tree.searchRecursively(tree.Root, key) if found { tree.delete(node, index) @@ -112,18 +109,18 @@ func (tree *Tree) Remove(key interface{}) { } // Empty returns true if tree does not contain any nodes -func (tree *Tree) Empty() bool { +func (tree *Tree[K, V]) Empty() bool { return tree.size == 0 } // Size returns number of nodes in the tree. -func (tree *Tree) Size() int { +func (tree *Tree[K, V]) Size() int { return tree.size } // Size returns the number of elements stored in the subtree. // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes. -func (node *Node) Size() int { +func (node *Node[K, V]) Size() int { if node == nil { return 0 } @@ -135,8 +132,8 @@ func (node *Node) Size() int { } // Keys returns all keys in-order -func (tree *Tree) Keys() []interface{} { - keys := make([]interface{}, tree.size) +func (tree *Tree[K, V]) Keys() []K { + keys := make([]K, tree.size) it := tree.Iterator() for i := 0; it.Next(); i++ { keys[i] = it.Key() @@ -145,8 +142,8 @@ func (tree *Tree) Keys() []interface{} { } // Values returns all values in-order based on the key. -func (tree *Tree) Values() []interface{} { - values := make([]interface{}, tree.size) +func (tree *Tree[K, V]) Values() []V { + values := make([]V, tree.size) it := tree.Iterator() for i := 0; it.Next(); i++ { values[i] = it.Value() @@ -155,23 +152,23 @@ func (tree *Tree) Values() []interface{} { } // Clear removes all nodes from the tree. -func (tree *Tree) Clear() { +func (tree *Tree[K, V]) Clear() { tree.Root = nil tree.size = 0 } // Height returns the height of the tree. -func (tree *Tree) Height() int { +func (tree *Tree[K, V]) Height() int { return tree.Root.height() } // Left returns the left-most (min) node or nil if tree is empty. -func (tree *Tree) Left() *Node { +func (tree *Tree[K, V]) Left() *Node[K, V] { return tree.left(tree.Root) } // LeftKey returns the left-most (min) key or nil if tree is empty. -func (tree *Tree) LeftKey() interface{} { +func (tree *Tree[K, V]) LeftKey() interface{} { if left := tree.Left(); left != nil { return left.Entries[0].Key } @@ -179,7 +176,7 @@ func (tree *Tree) LeftKey() interface{} { } // LeftValue returns the left-most value or nil if tree is empty. -func (tree *Tree) LeftValue() interface{} { +func (tree *Tree[K, V]) LeftValue() interface{} { if left := tree.Left(); left != nil { return left.Entries[0].Value } @@ -187,12 +184,12 @@ func (tree *Tree) LeftValue() interface{} { } // Right returns the right-most (max) node or nil if tree is empty. -func (tree *Tree) Right() *Node { +func (tree *Tree[K, V]) Right() *Node[K, V] { return tree.right(tree.Root) } // RightKey returns the right-most (max) key or nil if tree is empty. -func (tree *Tree) RightKey() interface{} { +func (tree *Tree[K, V]) RightKey() interface{} { if right := tree.Right(); right != nil { return right.Entries[len(right.Entries)-1].Key } @@ -200,7 +197,7 @@ func (tree *Tree) RightKey() interface{} { } // RightValue returns the right-most value or nil if tree is empty. -func (tree *Tree) RightValue() interface{} { +func (tree *Tree[K, V]) RightValue() interface{} { if right := tree.Right(); right != nil { return right.Entries[len(right.Entries)-1].Value } @@ -208,23 +205,23 @@ func (tree *Tree) RightValue() interface{} { } // String returns a string representation of container (for debugging purposes) -func (tree *Tree) String() string { +func (tree *Tree[K, V]) String() string { var buffer bytes.Buffer buffer.WriteString("BTree\n") if !tree.Empty() { - tree.output(&buffer, tree.Root, 0, true) + tree.output(&buffer, tree.Root, 0) } return buffer.String() } -func (entry *Entry) String() string { +func (entry *Entry[K, V]) String() string { return fmt.Sprintf("%v", entry.Key) } -func (tree *Tree) output(buffer *bytes.Buffer, node *Node, level int, isTail bool) { +func (tree *Tree[K, V]) output(buffer *bytes.Buffer, node *Node[K, V], level int) { for e := 0; e < len(node.Entries)+1; e++ { if e < len(node.Children) { - tree.output(buffer, node.Children[e], level+1, true) + tree.output(buffer, node.Children[e], level+1) } if e < len(node.Entries) { buffer.WriteString(strings.Repeat(" ", level)) @@ -233,7 +230,7 @@ func (tree *Tree) output(buffer *bytes.Buffer, node *Node, level int, isTail boo } } -func (node *Node) height() int { +func (node *Node[K, V]) height() int { height := 0 for ; node != nil; node = node.Children[0] { height++ @@ -244,40 +241,40 @@ func (node *Node) height() int { return height } -func (tree *Tree) isLeaf(node *Node) bool { +func (tree *Tree[K, V]) isLeaf(node *Node[K, V]) bool { return len(node.Children) == 0 } -func (tree *Tree) isFull(node *Node) bool { +func (tree *Tree[K, V]) isFull(node *Node[K, V]) bool { return len(node.Entries) == tree.maxEntries() } -func (tree *Tree) shouldSplit(node *Node) bool { +func (tree *Tree[K, V]) shouldSplit(node *Node[K, V]) bool { return len(node.Entries) > tree.maxEntries() } -func (tree *Tree) maxChildren() int { +func (tree *Tree[K, V]) maxChildren() int { return tree.m } -func (tree *Tree) minChildren() int { +func (tree *Tree[K, V]) minChildren() int { return (tree.m + 1) / 2 // ceil(m/2) } -func (tree *Tree) maxEntries() int { +func (tree *Tree[K, V]) maxEntries() int { return tree.maxChildren() - 1 } -func (tree *Tree) minEntries() int { +func (tree *Tree[K, V]) minEntries() int { return tree.minChildren() - 1 } -func (tree *Tree) middle() int { +func (tree *Tree[K, V]) middle() int { return (tree.m - 1) / 2 // "-1" to favor right nodes to have more keys when splitting } // search searches only within the single node among its entries -func (tree *Tree) search(node *Node, key interface{}) (index int, found bool) { +func (tree *Tree[K, V]) search(node *Node[K, V], key K) (index int, found bool) { low, high := 0, len(node.Entries)-1 var mid int for low <= high { @@ -296,7 +293,7 @@ func (tree *Tree) search(node *Node, key interface{}) (index int, found bool) { } // searchRecursively searches recursively down the tree starting at the startNode -func (tree *Tree) searchRecursively(startNode *Node, key interface{}) (node *Node, index int, found bool) { +func (tree *Tree[K, V]) searchRecursively(startNode *Node[K, V], key K) (node *Node[K, V], index int, found bool) { if tree.Empty() { return nil, -1, false } @@ -313,14 +310,14 @@ func (tree *Tree) searchRecursively(startNode *Node, key interface{}) (node *Nod } } -func (tree *Tree) insert(node *Node, entry *Entry) (inserted bool) { +func (tree *Tree[K, V]) insert(node *Node[K, V], entry *Entry[K, V]) (inserted bool) { if tree.isLeaf(node) { return tree.insertIntoLeaf(node, entry) } return tree.insertIntoInternal(node, entry) } -func (tree *Tree) insertIntoLeaf(node *Node, entry *Entry) (inserted bool) { +func (tree *Tree[K, V]) insertIntoLeaf(node *Node[K, V], entry *Entry[K, V]) (inserted bool) { insertPosition, found := tree.search(node, entry.Key) if found { node.Entries[insertPosition] = entry @@ -334,7 +331,7 @@ func (tree *Tree) insertIntoLeaf(node *Node, entry *Entry) (inserted bool) { return true } -func (tree *Tree) insertIntoInternal(node *Node, entry *Entry) (inserted bool) { +func (tree *Tree[K, V]) insertIntoInternal(node *Node[K, V], entry *Entry[K, V]) (inserted bool) { insertPosition, found := tree.search(node, entry.Key) if found { node.Entries[insertPosition] = entry @@ -343,7 +340,7 @@ func (tree *Tree) insertIntoInternal(node *Node, entry *Entry) (inserted bool) { return tree.insert(node.Children[insertPosition], entry) } -func (tree *Tree) split(node *Node) { +func (tree *Tree[K, V]) split(node *Node[K, V]) { if !tree.shouldSplit(node) { return } @@ -356,17 +353,17 @@ func (tree *Tree) split(node *Node) { tree.splitNonRoot(node) } -func (tree *Tree) splitNonRoot(node *Node) { +func (tree *Tree[K, V]) splitNonRoot(node *Node[K, V]) { middle := tree.middle() parent := node.Parent - left := &Node{Entries: append([]*Entry(nil), node.Entries[:middle]...), Parent: parent} - right := &Node{Entries: append([]*Entry(nil), node.Entries[middle+1:]...), Parent: parent} + left := &Node[K, V]{Entries: append([]*Entry[K, V](nil), node.Entries[:middle]...), Parent: parent} + right := &Node[K, V]{Entries: append([]*Entry[K, V](nil), node.Entries[middle+1:]...), Parent: parent} // Move children from the node to be split into left and right nodes if !tree.isLeaf(node) { - left.Children = append([]*Node(nil), node.Children[:middle+1]...) - right.Children = append([]*Node(nil), node.Children[middle+1:]...) + left.Children = append([]*Node[K, V](nil), node.Children[:middle+1]...) + right.Children = append([]*Node[K, V](nil), node.Children[middle+1:]...) setParent(left.Children, left) setParent(right.Children, right) } @@ -389,24 +386,24 @@ func (tree *Tree) splitNonRoot(node *Node) { tree.split(parent) } -func (tree *Tree) splitRoot() { +func (tree *Tree[K, V]) splitRoot() { middle := tree.middle() - left := &Node{Entries: append([]*Entry(nil), tree.Root.Entries[:middle]...)} - right := &Node{Entries: append([]*Entry(nil), tree.Root.Entries[middle+1:]...)} + left := &Node[K, V]{Entries: append([]*Entry[K, V](nil), tree.Root.Entries[:middle]...)} + right := &Node[K, V]{Entries: append([]*Entry[K, V](nil), tree.Root.Entries[middle+1:]...)} // Move children from the node to be split into left and right nodes if !tree.isLeaf(tree.Root) { - left.Children = append([]*Node(nil), tree.Root.Children[:middle+1]...) - right.Children = append([]*Node(nil), tree.Root.Children[middle+1:]...) + left.Children = append([]*Node[K, V](nil), tree.Root.Children[:middle+1]...) + right.Children = append([]*Node[K, V](nil), tree.Root.Children[middle+1:]...) setParent(left.Children, left) setParent(right.Children, right) } // Root is a node with one entry and two children (left and right) - newRoot := &Node{ - Entries: []*Entry{tree.Root.Entries[middle]}, - Children: []*Node{left, right}, + newRoot := &Node[K, V]{ + Entries: []*Entry[K, V]{tree.Root.Entries[middle]}, + Children: []*Node[K, V]{left, right}, } left.Parent = newRoot @@ -414,13 +411,13 @@ func (tree *Tree) splitRoot() { tree.Root = newRoot } -func setParent(nodes []*Node, parent *Node) { +func setParent[K comparable, V any](nodes []*Node[K, V], parent *Node[K, V]) { for _, node := range nodes { node.Parent = parent } } -func (tree *Tree) left(node *Node) *Node { +func (tree *Tree[K, V]) left(node *Node[K, V]) *Node[K, V] { if tree.Empty() { return nil } @@ -433,7 +430,7 @@ func (tree *Tree) left(node *Node) *Node { } } -func (tree *Tree) right(node *Node) *Node { +func (tree *Tree[K, V]) right(node *Node[K, V]) *Node[K, V] { if tree.Empty() { return nil } @@ -448,7 +445,7 @@ func (tree *Tree) right(node *Node) *Node { // leftSibling returns the node's left sibling and child index (in parent) if it exists, otherwise (nil,-1) // key is any of keys in node (could even be deleted). -func (tree *Tree) leftSibling(node *Node, key interface{}) (*Node, int) { +func (tree *Tree[K, V]) leftSibling(node *Node[K, V], key K) (*Node[K, V], int) { if node.Parent != nil { index, _ := tree.search(node.Parent, key) index-- @@ -461,7 +458,7 @@ func (tree *Tree) leftSibling(node *Node, key interface{}) (*Node, int) { // rightSibling returns the node's right sibling and child index (in parent) if it exists, otherwise (nil,-1) // key is any of keys in node (could even be deleted). -func (tree *Tree) rightSibling(node *Node, key interface{}) (*Node, int) { +func (tree *Tree[K, V]) rightSibling(node *Node[K, V], key K) (*Node[K, V], int) { if node.Parent != nil { index, _ := tree.search(node.Parent, key) index++ @@ -474,7 +471,7 @@ func (tree *Tree) rightSibling(node *Node, key interface{}) (*Node, int) { // delete deletes an entry in node at entries' index // ref.: https://en.wikipedia.org/wiki/B-tree#Deletion -func (tree *Tree) delete(node *Node, index int) { +func (tree *Tree[K, V]) delete(node *Node[K, V], index int) { // deleting from a leaf node if tree.isLeaf(node) { deletedKey := node.Entries[index].Key @@ -497,7 +494,7 @@ func (tree *Tree) delete(node *Node, index int) { // rebalance rebalances the tree after deletion if necessary and returns true, otherwise false. // Note that we first delete the entry and then call rebalance, thus the passed deleted key as reference. -func (tree *Tree) rebalance(node *Node, deletedKey interface{}) { +func (tree *Tree[K, V]) rebalance(node *Node[K, V], deletedKey K) { // check if rebalancing is needed if node == nil || len(node.Entries) >= tree.minEntries() { return @@ -507,13 +504,13 @@ func (tree *Tree) rebalance(node *Node, deletedKey interface{}) { leftSibling, leftSiblingIndex := tree.leftSibling(node, deletedKey) if leftSibling != nil && len(leftSibling.Entries) > tree.minEntries() { // rotate right - node.Entries = append([]*Entry{node.Parent.Entries[leftSiblingIndex]}, node.Entries...) // prepend parent's separator entry to node's entries + node.Entries = append([]*Entry[K, V]{node.Parent.Entries[leftSiblingIndex]}, node.Entries...) // prepend parent's separator entry to node's entries node.Parent.Entries[leftSiblingIndex] = leftSibling.Entries[len(leftSibling.Entries)-1] tree.deleteEntry(leftSibling, len(leftSibling.Entries)-1) if !tree.isLeaf(leftSibling) { leftSiblingRightMostChild := leftSibling.Children[len(leftSibling.Children)-1] leftSiblingRightMostChild.Parent = node - node.Children = append([]*Node{leftSiblingRightMostChild}, node.Children...) + node.Children = append([]*Node[K, V]{leftSiblingRightMostChild}, node.Children...) tree.deleteChild(leftSibling, len(leftSibling.Children)-1) } return @@ -546,7 +543,7 @@ func (tree *Tree) rebalance(node *Node, deletedKey interface{}) { tree.deleteChild(node.Parent, rightSiblingIndex) } else if leftSibling != nil { // merge with left sibling - entries := append([]*Entry(nil), leftSibling.Entries...) + entries := append([]*Entry[K, V](nil), leftSibling.Entries...) entries = append(entries, node.Parent.Entries[leftSiblingIndex]) node.Entries = append(entries, node.Entries...) deletedKey = node.Parent.Entries[leftSiblingIndex].Key @@ -566,24 +563,24 @@ func (tree *Tree) rebalance(node *Node, deletedKey interface{}) { tree.rebalance(node.Parent, deletedKey) } -func (tree *Tree) prependChildren(fromNode *Node, toNode *Node) { - children := append([]*Node(nil), fromNode.Children...) +func (tree *Tree[K, V]) prependChildren(fromNode *Node[K, V], toNode *Node[K, V]) { + children := append([]*Node[K, V](nil), fromNode.Children...) toNode.Children = append(children, toNode.Children...) setParent(fromNode.Children, toNode) } -func (tree *Tree) appendChildren(fromNode *Node, toNode *Node) { +func (tree *Tree[K, V]) appendChildren(fromNode *Node[K, V], toNode *Node[K, V]) { toNode.Children = append(toNode.Children, fromNode.Children...) setParent(fromNode.Children, toNode) } -func (tree *Tree) deleteEntry(node *Node, index int) { +func (tree *Tree[K, V]) deleteEntry(node *Node[K, V], index int) { copy(node.Entries[index:], node.Entries[index+1:]) node.Entries[len(node.Entries)-1] = nil node.Entries = node.Entries[:len(node.Entries)-1] } -func (tree *Tree) deleteChild(node *Node, index int) { +func (tree *Tree[K, V]) deleteChild(node *Node[K, V], index int) { if index >= len(node.Children) { return } diff --git a/trees/btree/btree_test.go b/trees/btree/btree_test.go index 8dcb2581..c6af7c0e 100644 --- a/trees/btree/btree_test.go +++ b/trees/btree/btree_test.go @@ -6,13 +6,13 @@ package btree import ( "encoding/json" - "fmt" + "slices" "strings" "testing" ) func TestBTreeGet1(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(1, "a") tree.Put(2, "b") tree.Put(3, "c") @@ -22,7 +22,7 @@ func TestBTreeGet1(t *testing.T) { tree.Put(7, "g") tests := [][]interface{}{ - {0, nil, false}, + {0, "", false}, {1, "a", true}, {2, "b", true}, {3, "c", true}, @@ -30,18 +30,18 @@ func TestBTreeGet1(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests { - if value, found := tree.Get(test[0]); value != test[1] || found != test[2] { + if value, found := tree.Get(test[0].(int)); value != test[1] || found != test[2] { t.Errorf("Got %v,%v expected %v,%v", value, found, test[1], test[2]) } } } func TestBTreeGet2(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(7, "g") tree.Put(9, "i") tree.Put(10, "j") @@ -54,7 +54,7 @@ func TestBTreeGet2(t *testing.T) { tree.Put(1, "a") tests := [][]interface{}{ - {0, nil, false}, + {0, "", false}, {1, "a", true}, {2, "b", true}, {3, "c", true}, @@ -65,18 +65,18 @@ func TestBTreeGet2(t *testing.T) { {8, "h", true}, {9, "i", true}, {10, "j", true}, - {11, nil, false}, + {11, "", false}, } for _, test := range tests { - if value, found := tree.Get(test[0]); value != test[1] || found != test[2] { + if value, found := tree.Get(test[0].(int)); value != test[1] || found != test[2] { t.Errorf("Got %v,%v expected %v,%v", value, found, test[1], test[2]) } } } func TestBTreeGet3(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) if actualValue := tree.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) @@ -123,7 +123,7 @@ func TestBTreeGet3(t *testing.T) { func TestBTreePut1(t *testing.T) { // https://upload.wikimedia.org/wikipedia/commons/3/33/B_tree_insertion_example.png - tree := NewWithIntComparator(3) + tree := New[int, int](3) assertValidTree(t, tree, 0) tree.Put(1, 0) @@ -172,7 +172,7 @@ func TestBTreePut1(t *testing.T) { } func TestBTreePut2(t *testing.T) { - tree := NewWithIntComparator(4) + tree := New[int, int](4) assertValidTree(t, tree, 0) tree.Put(0, 0) @@ -213,7 +213,7 @@ func TestBTreePut2(t *testing.T) { func TestBTreePut3(t *testing.T) { // http://www.geeksforgeeks.org/b-tree-set-1-insert-2/ - tree := NewWithIntComparator(6) + tree := New[int, int](6) assertValidTree(t, tree, 0) tree.Put(10, 0) @@ -263,7 +263,7 @@ func TestBTreePut3(t *testing.T) { } func TestBTreePut4(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) assertValidTree(t, tree, 0) tree.Put(6, nil) @@ -358,14 +358,14 @@ func TestBTreePut4(t *testing.T) { func TestBTreeRemove1(t *testing.T) { // empty - tree := NewWithIntComparator(3) + tree := New[int, int](3) tree.Remove(1) assertValidTree(t, tree, 0) } func TestBTreeRemove2(t *testing.T) { // leaf node (no underflow) - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) @@ -380,7 +380,7 @@ func TestBTreeRemove2(t *testing.T) { func TestBTreeRemove3(t *testing.T) { // merge with right (underflow) { - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -391,7 +391,7 @@ func TestBTreeRemove3(t *testing.T) { } // merge with left (underflow) { - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -404,7 +404,7 @@ func TestBTreeRemove3(t *testing.T) { func TestBTreeRemove4(t *testing.T) { // rotate left (underflow) - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -424,7 +424,7 @@ func TestBTreeRemove4(t *testing.T) { func TestBTreeRemove5(t *testing.T) { // rotate right (underflow) - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -445,7 +445,7 @@ func TestBTreeRemove5(t *testing.T) { func TestBTreeRemove6(t *testing.T) { // root height reduction after a series of underflows on right side // use simulator: https://www.cs.usfca.edu/~galles/visualization/BTree.html - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -474,7 +474,7 @@ func TestBTreeRemove6(t *testing.T) { func TestBTreeRemove7(t *testing.T) { // root height reduction after a series of underflows on left side // use simulator: https://www.cs.usfca.edu/~galles/visualization/BTree.html - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -533,7 +533,7 @@ func TestBTreeRemove7(t *testing.T) { func TestBTreeRemove8(t *testing.T) { // use simulator: https://www.cs.usfca.edu/~galles/visualization/BTree.html - tree := NewWithIntComparator(3) + tree := New[int, *struct{}](3) tree.Put(1, nil) tree.Put(2, nil) tree.Put(3, nil) @@ -570,7 +570,7 @@ func TestBTreeRemove9(t *testing.T) { orders := []int{3, 4, 5, 6, 7, 8, 9, 10, 20, 100, 500, 1000, 5000, 10000} for _, order := range orders { - tree := NewWithIntComparator(order) + tree := New[int, int](order) { for i := 1; i <= max; i++ { @@ -611,7 +611,7 @@ func TestBTreeRemove9(t *testing.T) { } func TestBTreeHeight(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, int](3) if actualValue, expectedValue := tree.Height(), 0; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -664,7 +664,7 @@ func TestBTreeHeight(t *testing.T) { } func TestBTreeLeftAndRight(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) if actualValue := tree.Left(); actualValue != nil { t.Errorf("Got %v expected %v", actualValue, nil) @@ -698,7 +698,7 @@ func TestBTreeLeftAndRight(t *testing.T) { } func TestBTreeIteratorValuesAndKeys(t *testing.T) { - tree := NewWithIntComparator(4) + tree := New[int, string](4) tree.Put(4, "d") tree.Put(5, "e") tree.Put(6, "f") @@ -707,10 +707,10 @@ func TestBTreeIteratorValuesAndKeys(t *testing.T) { tree.Put(7, "g") tree.Put(2, "b") tree.Put(1, "x") // override - if actualValue, expectedValue := fmt.Sprintf("%d%d%d%d%d%d%d", tree.Keys()...), "1234567"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{1, 2, 3, 4, 5, 6, 7}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s%s%s%s", tree.Values()...), "xbcdefg"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{"x", "b", "c", "d", "e", "f", "g"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue := tree.Size(); actualValue != 7 { @@ -719,7 +719,7 @@ func TestBTreeIteratorValuesAndKeys(t *testing.T) { } func TestBTreeIteratorNextOnEmpty(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) it := tree.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty tree") @@ -727,7 +727,7 @@ func TestBTreeIteratorNextOnEmpty(t *testing.T) { } func TestBTreeIteratorPrevOnEmpty(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) it := tree.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty tree") @@ -735,7 +735,7 @@ func TestBTreeIteratorPrevOnEmpty(t *testing.T) { } func TestBTreeIterator1Next(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -759,7 +759,7 @@ func TestBTreeIterator1Next(t *testing.T) { } func TestBTreeIterator1Prev(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -785,7 +785,7 @@ func TestBTreeIterator1Prev(t *testing.T) { } func TestBTreeIterator2Next(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -804,7 +804,7 @@ func TestBTreeIterator2Next(t *testing.T) { } func TestBTreeIterator2Prev(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -825,7 +825,7 @@ func TestBTreeIterator2Prev(t *testing.T) { } func TestBTreeIterator3Next(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(1, "a") it := tree.Iterator() count := 0 @@ -842,7 +842,7 @@ func TestBTreeIterator3Next(t *testing.T) { } func TestBTreeIterator3Prev(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(1, "a") it := tree.Iterator() for it.Next() { @@ -861,7 +861,7 @@ func TestBTreeIterator3Prev(t *testing.T) { } func TestBTreeIterator4Next(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, int](3) tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -887,7 +887,7 @@ func TestBTreeIterator4Next(t *testing.T) { } func TestBTreeIterator4Prev(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, int](3) tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -915,7 +915,7 @@ func TestBTreeIterator4Prev(t *testing.T) { } func TestBTreeIteratorBegin(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -947,7 +947,7 @@ func TestBTreeIteratorBegin(t *testing.T) { } func TestBTreeIteratorEnd(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) it := tree.Iterator() if it.node != nil { @@ -974,7 +974,7 @@ func TestBTreeIteratorEnd(t *testing.T) { } func TestBTreeIteratorFirst(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -988,7 +988,7 @@ func TestBTreeIteratorFirst(t *testing.T) { } func TestBTreeIteratorLast(t *testing.T) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -1003,13 +1003,13 @@ func TestBTreeIteratorLast(t *testing.T) { func TestBTreeSearch(t *testing.T) { { - tree := NewWithIntComparator(3) - tree.Root = &Node{Entries: []*Entry{}, Children: make([]*Node, 0)} + tree := New[int, int](3) + tree.Root = &Node[int, int]{Entries: []*Entry[int, int]{}, Children: make([]*Node[int, int], 0)} tests := [][]interface{}{ {0, 0, false}, } for _, test := range tests { - index, found := tree.search(tree.Root, test[0]) + index, found := tree.search(tree.Root, test[0].(int)) if actualValue, expectedValue := index, test[1]; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -1019,8 +1019,8 @@ func TestBTreeSearch(t *testing.T) { } } { - tree := NewWithIntComparator(3) - tree.Root = &Node{Entries: []*Entry{{2, 0}, {4, 1}, {6, 2}}, Children: []*Node{}} + tree := New[int, int](3) + tree.Root = &Node[int, int]{Entries: []*Entry[int, int]{{2, 0}, {4, 1}, {6, 2}}, Children: []*Node[int, int]{}} tests := [][]interface{}{ {0, 0, false}, {1, 0, false}, @@ -1032,7 +1032,7 @@ func TestBTreeSearch(t *testing.T) { {7, 3, false}, } for _, test := range tests { - index, found := tree.search(tree.Root, test[0]) + index, found := tree.search(tree.Root, test[0].(int)) if actualValue, expectedValue := index, test[1]; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -1043,13 +1043,13 @@ func TestBTreeSearch(t *testing.T) { } } -func assertValidTree(t *testing.T, tree *Tree, expectedSize int) { +func assertValidTree[K comparable, V any](t *testing.T, tree *Tree[K, V], expectedSize int) { if actualValue, expectedValue := tree.size, expectedSize; actualValue != expectedValue { t.Errorf("Got %v expected %v for tree size", actualValue, expectedValue) } } -func assertValidTreeNode(t *testing.T, node *Node, expectedEntries int, expectedChildren int, keys []int, hasParent bool) { +func assertValidTreeNode[K comparable, V any](t *testing.T, node *Node[K, V], expectedEntries int, expectedChildren int, keys []K, hasParent bool) { if actualValue, expectedValue := node.Parent != nil, hasParent; actualValue != expectedValue { t.Errorf("Got %v expected %v for hasParent", actualValue, expectedValue) } @@ -1068,13 +1068,13 @@ func assertValidTreeNode(t *testing.T, node *Node, expectedEntries int, expected func TestBTreeIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) it := tree.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") @@ -1083,7 +1083,7 @@ func TestBTreeIteratorNextTo(t *testing.T) { // NextTo (not found) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -1094,7 +1094,7 @@ func TestBTreeIteratorNextTo(t *testing.T) { // NextTo (found) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -1103,13 +1103,13 @@ func TestBTreeIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -1120,13 +1120,13 @@ func TestBTreeIteratorNextTo(t *testing.T) { func TestBTreeIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) it := tree.Iterator() it.End() for it.PrevTo(seek) { @@ -1136,7 +1136,7 @@ func TestBTreeIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -1148,7 +1148,7 @@ func TestBTreeIteratorPrevTo(t *testing.T) { // PrevTo (found) { - tree := NewWithIntComparator(3) + tree := New[int, string](3) tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -1157,13 +1157,13 @@ func TestBTreeIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -1173,7 +1173,7 @@ func TestBTreeIteratorPrevTo(t *testing.T) { } func TestBTreeSerialization(t *testing.T) { - tree := NewWithStringComparator(3) + tree := New[string, string](3) tree.Put("c", "3") tree.Put("b", "2") tree.Put("a", "1") @@ -1183,11 +1183,11 @@ func TestBTreeSerialization(t *testing.T) { if actualValue, expectedValue := tree.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Keys(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { - t.Errorf("Got %v expected %v", actualValue, "[a,b,c]") + if actualValue, expectedValue := tree.Keys(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Values(); actualValue[0].(string) != "1" || actualValue[1].(string) != "2" || actualValue[2].(string) != "3" { - t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") + if actualValue, expectedValue := tree.Values(), []string{"1", "2", "3"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if err != nil { t.Errorf("Got error %v", err) @@ -1207,21 +1207,31 @@ func TestBTreeSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`{"a":1,"b":2}`), &tree) + intTree := New[string, int](3) + err = json.Unmarshal([]byte(`{"a":1,"b":2}`), intTree) if err != nil { t.Errorf("Got error %v", err) } + if actualValue, expectedValue := intTree.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Keys(), []string{"a", "b"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Values(), []int{1, 2}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } } func TestBTreeString(t *testing.T) { - c := NewWithStringComparator(3) + c := New[string, int](3) c.Put("a", 1) if !strings.HasPrefix(c.String(), "BTree") { t.Errorf("String should start with container name") } } -func benchmarkGet(b *testing.B, tree *Tree, size int) { +func benchmarkGet(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Get(n) @@ -1229,7 +1239,7 @@ func benchmarkGet(b *testing.B, tree *Tree, size int) { } } -func benchmarkPut(b *testing.B, tree *Tree, size int) { +func benchmarkPut(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Put(n, struct{}{}) @@ -1237,7 +1247,7 @@ func benchmarkPut(b *testing.B, tree *Tree, size int) { } } -func benchmarkRemove(b *testing.B, tree *Tree, size int) { +func benchmarkRemove(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Remove(n) @@ -1248,7 +1258,7 @@ func benchmarkRemove(b *testing.B, tree *Tree, size int) { func BenchmarkBTreeGet100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1259,7 +1269,7 @@ func BenchmarkBTreeGet100(b *testing.B) { func BenchmarkBTreeGet1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1270,7 +1280,7 @@ func BenchmarkBTreeGet1000(b *testing.B) { func BenchmarkBTreeGet10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1281,7 +1291,7 @@ func BenchmarkBTreeGet10000(b *testing.B) { func BenchmarkBTreeGet100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1292,7 +1302,7 @@ func BenchmarkBTreeGet100000(b *testing.B) { func BenchmarkBTreePut100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) b.StartTimer() benchmarkPut(b, tree, size) } @@ -1300,7 +1310,7 @@ func BenchmarkBTreePut100(b *testing.B) { func BenchmarkBTreePut1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1311,7 +1321,7 @@ func BenchmarkBTreePut1000(b *testing.B) { func BenchmarkBTreePut10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1322,7 +1332,7 @@ func BenchmarkBTreePut10000(b *testing.B) { func BenchmarkBTreePut100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1333,7 +1343,7 @@ func BenchmarkBTreePut100000(b *testing.B) { func BenchmarkBTreeRemove100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1344,7 +1354,7 @@ func BenchmarkBTreeRemove100(b *testing.B) { func BenchmarkBTreeRemove1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1355,7 +1365,7 @@ func BenchmarkBTreeRemove1000(b *testing.B) { func BenchmarkBTreeRemove10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -1366,7 +1376,7 @@ func BenchmarkBTreeRemove10000(b *testing.B) { func BenchmarkBTreeRemove100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator(128) + tree := New[int, struct{}](128) for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } diff --git a/trees/btree/iterator.go b/trees/btree/iterator.go index fb20955a..23c9387d 100644 --- a/trees/btree/iterator.go +++ b/trees/btree/iterator.go @@ -4,16 +4,16 @@ package btree -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - tree *Tree - node *Node - entry *Entry +type Iterator[K comparable, V any] struct { + tree *Tree[K, V] + node *Node[K, V] + entry *Entry[K, V] position position } @@ -24,15 +24,15 @@ const ( ) // Iterator returns a stateful iterator whose elements are key/value pairs. -func (tree *Tree) Iterator() Iterator { - return Iterator{tree: tree, node: nil, position: begin} +func (tree *Tree[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{tree: tree, node: nil, position: begin} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { // If already at end, go to end if iterator.position == end { goto end @@ -91,7 +91,7 @@ between: // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { // If already at beginning, go to begin if iterator.position == begin { goto begin @@ -149,25 +149,25 @@ between: // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[K, V]) Value() V { return iterator.entry.Value } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() K { return iterator.entry.Key } // Node returns the current element's node. // Does not modify the state of the iterator. -func (iterator *Iterator) Node() *Node { +func (iterator *Iterator[K, V]) Node() *Node[K, V] { return iterator.node } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.node = nil iterator.position = begin iterator.entry = nil @@ -175,7 +175,7 @@ func (iterator *Iterator) Begin() { // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.node = nil iterator.position = end iterator.entry = nil @@ -184,7 +184,7 @@ func (iterator *Iterator) End() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { iterator.Begin() return iterator.Next() } @@ -192,7 +192,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { iterator.End() return iterator.Prev() } @@ -201,7 +201,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -215,7 +215,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/trees/btree/serialization.go b/trees/btree/serialization.go index 460f6e0b..4b9a3139 100644 --- a/trees/btree/serialization.go +++ b/trees/btree/serialization.go @@ -6,43 +6,46 @@ package btree import ( "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Tree)(nil) -var _ containers.JSONDeserializer = (*Tree)(nil) +var _ containers.JSONSerializer = (*Tree[string, int])(nil) +var _ containers.JSONDeserializer = (*Tree[string, int])(nil) // ToJSON outputs the JSON representation of the tree. -func (tree *Tree) ToJSON() ([]byte, error) { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) ToJSON() ([]byte, error) { + elements := make(map[K]V) it := tree.Iterator() for it.Next() { - elements[utils.ToString(it.Key())] = it.Value() + elements[it.Key()] = it.Value() } return json.Marshal(&elements) } // FromJSON populates the tree from the input JSON representation. -func (tree *Tree) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) FromJSON(data []byte) error { + elements := make(map[K]V) err := json.Unmarshal(data, &elements) - if err == nil { - tree.Clear() - for key, value := range elements { - tree.Put(key, value) - } + if err != nil { + return err + } + + tree.Clear() + for key, value := range elements { + tree.Put(key, value) } + return err } // UnmarshalJSON @implements json.Unmarshaler -func (tree *Tree) UnmarshalJSON(bytes []byte) error { +func (tree *Tree[K, V]) UnmarshalJSON(bytes []byte) error { return tree.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (tree *Tree) MarshalJSON() ([]byte, error) { +func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) { return tree.ToJSON() } diff --git a/trees/redblacktree/iterator.go b/trees/redblacktree/iterator.go index e39da7d4..9aa5605e 100644 --- a/trees/redblacktree/iterator.go +++ b/trees/redblacktree/iterator.go @@ -4,15 +4,15 @@ package redblacktree -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Assert Iterator implementation -var _ containers.ReverseIteratorWithKey = (*Iterator)(nil) +var _ containers.ReverseIteratorWithKey[string, int] = (*Iterator[string, int])(nil) // Iterator holding the iterator's state -type Iterator struct { - tree *Tree - node *Node +type Iterator[K comparable, V any] struct { + tree *Tree[K, V] + node *Node[K, V] position position } @@ -23,20 +23,20 @@ const ( ) // Iterator returns a stateful iterator whose elements are key/value pairs. -func (tree *Tree) Iterator() Iterator { - return Iterator{tree: tree, node: nil, position: begin} +func (tree *Tree[K, V]) Iterator() *Iterator[K, V] { + return &Iterator[K, V]{tree: tree, node: nil, position: begin} } // IteratorAt returns a stateful iterator whose elements are key/value pairs that is initialised at a particular node. -func (tree *Tree) IteratorAt(node *Node) Iterator { - return Iterator{tree: tree, node: node, position: between} +func (tree *Tree[K, V]) IteratorAt(node *Node[K, V]) *Iterator[K, V] { + return &Iterator[K, V]{tree: tree, node: node, position: between} } // Next moves the iterator to the next element and returns true if there was a next element in the container. // If Next() returns true, then next element's key and value can be retrieved by Key() and Value(). // If Next() was called for the first time, then it will point the iterator to the first element if it exists. // Modifies the state of the iterator. -func (iterator *Iterator) Next() bool { +func (iterator *Iterator[K, V]) Next() bool { if iterator.position == end { goto end } @@ -76,7 +76,7 @@ between: // Prev moves the iterator to the previous element and returns true if there was a previous element in the container. // If Prev() returns true, then previous element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Prev() bool { +func (iterator *Iterator[K, V]) Prev() bool { if iterator.position == begin { goto begin } @@ -115,32 +115,32 @@ between: // Value returns the current element's value. // Does not modify the state of the iterator. -func (iterator *Iterator) Value() interface{} { +func (iterator *Iterator[K, V]) Value() V { return iterator.node.Value } // Key returns the current element's key. // Does not modify the state of the iterator. -func (iterator *Iterator) Key() interface{} { +func (iterator *Iterator[K, V]) Key() K { return iterator.node.Key } // Node returns the current element's node. // Does not modify the state of the iterator. -func (iterator *Iterator) Node() *Node { +func (iterator *Iterator[K, V]) Node() *Node[K, V] { return iterator.node } // Begin resets the iterator to its initial state (one-before-first) // Call Next() to fetch the first element if any. -func (iterator *Iterator) Begin() { +func (iterator *Iterator[K, V]) Begin() { iterator.node = nil iterator.position = begin } // End moves the iterator past the last element (one-past-the-end). // Call Prev() to fetch the last element if any. -func (iterator *Iterator) End() { +func (iterator *Iterator[K, V]) End() { iterator.node = nil iterator.position = end } @@ -148,7 +148,7 @@ func (iterator *Iterator) End() { // First moves the iterator to the first element and returns true if there was a first element in the container. // If First() returns true, then first element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator -func (iterator *Iterator) First() bool { +func (iterator *Iterator[K, V]) First() bool { iterator.Begin() return iterator.Next() } @@ -156,7 +156,7 @@ func (iterator *Iterator) First() bool { // Last moves the iterator to the last element and returns true if there was a last element in the container. // If Last() returns true, then last element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) Last() bool { +func (iterator *Iterator[K, V]) Last() bool { iterator.End() return iterator.Prev() } @@ -165,7 +165,7 @@ func (iterator *Iterator) Last() bool { // passed function, and returns true if there was a next element in the container. // If NextTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) NextTo(f func(key K, value V) bool) bool { for iterator.Next() { key, value := iterator.Key(), iterator.Value() if f(key, value) { @@ -179,7 +179,7 @@ func (iterator *Iterator) NextTo(f func(key interface{}, value interface{}) bool // passed function, and returns true if there was a next element in the container. // If PrevTo() returns true, then next element's key and value can be retrieved by Key() and Value(). // Modifies the state of the iterator. -func (iterator *Iterator) PrevTo(f func(key interface{}, value interface{}) bool) bool { +func (iterator *Iterator[K, V]) PrevTo(f func(key K, value V) bool) bool { for iterator.Prev() { key, value := iterator.Key(), iterator.Value() if f(key, value) { diff --git a/trees/redblacktree/redblacktree.go b/trees/redblacktree/redblacktree.go index b335e3df..ec53532d 100644 --- a/trees/redblacktree/redblacktree.go +++ b/trees/redblacktree/redblacktree.go @@ -12,13 +12,15 @@ package redblacktree import ( + "cmp" "fmt" - "github.com/emirpasic/gods/trees" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/trees" + "github.com/emirpasic/gods/v2/utils" ) // Assert Tree implementation -var _ trees.Tree = (*Tree)(nil) +var _ trees.Tree[int] = (*Tree[string, int])(nil) type color bool @@ -27,45 +29,40 @@ const ( ) // Tree holds elements of the red-black tree -type Tree struct { - Root *Node +type Tree[K comparable, V any] struct { + Root *Node[K, V] size int - Comparator utils.Comparator + Comparator utils.Comparator[K] } // Node is a single element within the tree -type Node struct { - Key interface{} - Value interface{} +type Node[K comparable, V any] struct { + Key K + Value V color color - Left *Node - Right *Node - Parent *Node -} - -// NewWith instantiates a red-black tree with the custom comparator. -func NewWith(comparator utils.Comparator) *Tree { - return &Tree{Comparator: comparator} + Left *Node[K, V] + Right *Node[K, V] + Parent *Node[K, V] } -// NewWithIntComparator instantiates a red-black tree with the IntComparator, i.e. keys are of type int. -func NewWithIntComparator() *Tree { - return &Tree{Comparator: utils.IntComparator} +// New instantiates a red-black tree with the built-in comparator for K +func New[K cmp.Ordered, V any]() *Tree[K, V] { + return &Tree[K, V]{Comparator: cmp.Compare[K]} } -// NewWithStringComparator instantiates a red-black tree with the StringComparator, i.e. keys are of type string. -func NewWithStringComparator() *Tree { - return &Tree{Comparator: utils.StringComparator} +// NewWith instantiates a red-black tree with the custom comparator. +func NewWith[K comparable, V any](comparator utils.Comparator[K]) *Tree[K, V] { + return &Tree[K, V]{Comparator: comparator} } // Put inserts node into the tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Put(key interface{}, value interface{}) { - var insertedNode *Node +func (tree *Tree[K, V]) Put(key K, value V) { + var insertedNode *Node[K, V] if tree.Root == nil { // Assert key is of comparator's type for initial tree tree.Comparator(key, key) - tree.Root = &Node{Key: key, Value: value, color: red} + tree.Root = &Node[K, V]{Key: key, Value: value, color: red} insertedNode = tree.Root } else { node := tree.Root @@ -79,7 +76,7 @@ func (tree *Tree) Put(key interface{}, value interface{}) { return case compare < 0: if node.Left == nil { - node.Left = &Node{Key: key, Value: value, color: red} + node.Left = &Node[K, V]{Key: key, Value: value, color: red} insertedNode = node.Left loop = false } else { @@ -87,7 +84,7 @@ func (tree *Tree) Put(key interface{}, value interface{}) { } case compare > 0: if node.Right == nil { - node.Right = &Node{Key: key, Value: value, color: red} + node.Right = &Node[K, V]{Key: key, Value: value, color: red} insertedNode = node.Right loop = false } else { @@ -104,24 +101,24 @@ func (tree *Tree) Put(key interface{}, value interface{}) { // Get searches the node in the tree by key and returns its value or nil if key is not found in tree. // Second return parameter is true if key was found, otherwise false. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Get(key interface{}) (value interface{}, found bool) { +func (tree *Tree[K, V]) Get(key K) (value V, found bool) { node := tree.lookup(key) if node != nil { return node.Value, true } - return nil, false + return value, false } // GetNode searches the node in the tree by key and returns its node or nil if key is not found in tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) GetNode(key interface{}) *Node { +func (tree *Tree[K, V]) GetNode(key K) *Node[K, V] { return tree.lookup(key) } // Remove remove the node from the tree by key. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Remove(key interface{}) { - var child *Node +func (tree *Tree[K, V]) Remove(key K) { + var child *Node[K, V] node := tree.lookup(key) if node == nil { return @@ -151,18 +148,18 @@ func (tree *Tree) Remove(key interface{}) { } // Empty returns true if tree does not contain any nodes -func (tree *Tree) Empty() bool { +func (tree *Tree[K, V]) Empty() bool { return tree.size == 0 } // Size returns number of nodes in the tree. -func (tree *Tree) Size() int { +func (tree *Tree[K, V]) Size() int { return tree.size } // Size returns the number of elements stored in the subtree. // Computed dynamically on each call, i.e. the subtree is traversed to count the number of the nodes. -func (node *Node) Size() int { +func (node *Node[K, V]) Size() int { if node == nil { return 0 } @@ -177,8 +174,8 @@ func (node *Node) Size() int { } // Keys returns all keys in-order -func (tree *Tree) Keys() []interface{} { - keys := make([]interface{}, tree.size) +func (tree *Tree[K, V]) Keys() []K { + keys := make([]K, tree.size) it := tree.Iterator() for i := 0; it.Next(); i++ { keys[i] = it.Key() @@ -187,8 +184,8 @@ func (tree *Tree) Keys() []interface{} { } // Values returns all values in-order based on the key. -func (tree *Tree) Values() []interface{} { - values := make([]interface{}, tree.size) +func (tree *Tree[K, V]) Values() []V { + values := make([]V, tree.size) it := tree.Iterator() for i := 0; it.Next(); i++ { values[i] = it.Value() @@ -197,8 +194,8 @@ func (tree *Tree) Values() []interface{} { } // Left returns the left-most (min) node or nil if tree is empty. -func (tree *Tree) Left() *Node { - var parent *Node +func (tree *Tree[K, V]) Left() *Node[K, V] { + var parent *Node[K, V] current := tree.Root for current != nil { parent = current @@ -208,8 +205,8 @@ func (tree *Tree) Left() *Node { } // Right returns the right-most (max) node or nil if tree is empty. -func (tree *Tree) Right() *Node { - var parent *Node +func (tree *Tree[K, V]) Right() *Node[K, V] { + var parent *Node[K, V] current := tree.Root for current != nil { parent = current @@ -226,7 +223,7 @@ func (tree *Tree) Right() *Node { // all nodes in the tree are larger than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Floor(key interface{}) (floor *Node, found bool) { +func (tree *Tree[K, V]) Floor(key K) (floor *Node[K, V], found bool) { found = false node := tree.Root for node != nil { @@ -255,7 +252,7 @@ func (tree *Tree) Floor(key interface{}) (floor *Node, found bool) { // all nodes in the tree are smaller than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (tree *Tree) Ceiling(key interface{}) (ceiling *Node, found bool) { +func (tree *Tree[K, V]) Ceiling(key K) (ceiling *Node[K, V], found bool) { found = false node := tree.Root for node != nil { @@ -277,13 +274,13 @@ func (tree *Tree) Ceiling(key interface{}) (ceiling *Node, found bool) { } // Clear removes all nodes from the tree. -func (tree *Tree) Clear() { +func (tree *Tree[K, V]) Clear() { tree.Root = nil tree.size = 0 } // String returns a string representation of container -func (tree *Tree) String() string { +func (tree *Tree[K, V]) String() string { str := "RedBlackTree\n" if !tree.Empty() { output(tree.Root, "", true, &str) @@ -291,11 +288,11 @@ func (tree *Tree) String() string { return str } -func (node *Node) String() string { +func (node *Node[K, V]) String() string { return fmt.Sprintf("%v", node.Key) } -func output(node *Node, prefix string, isTail bool, str *string) { +func output[K comparable, V any](node *Node[K, V], prefix string, isTail bool, str *string) { if node.Right != nil { newPrefix := prefix if isTail { @@ -323,7 +320,7 @@ func output(node *Node, prefix string, isTail bool, str *string) { } } -func (tree *Tree) lookup(key interface{}) *Node { +func (tree *Tree[K, V]) lookup(key K) *Node[K, V] { node := tree.Root for node != nil { compare := tree.Comparator(key, node.Key) @@ -339,21 +336,21 @@ func (tree *Tree) lookup(key interface{}) *Node { return nil } -func (node *Node) grandparent() *Node { +func (node *Node[K, V]) grandparent() *Node[K, V] { if node != nil && node.Parent != nil { return node.Parent.Parent } return nil } -func (node *Node) uncle() *Node { +func (node *Node[K, V]) uncle() *Node[K, V] { if node == nil || node.Parent == nil || node.Parent.Parent == nil { return nil } return node.Parent.sibling() } -func (node *Node) sibling() *Node { +func (node *Node[K, V]) sibling() *Node[K, V] { if node == nil || node.Parent == nil { return nil } @@ -363,7 +360,7 @@ func (node *Node) sibling() *Node { return node.Parent.Left } -func (tree *Tree) rotateLeft(node *Node) { +func (tree *Tree[K, V]) rotateLeft(node *Node[K, V]) { right := node.Right tree.replaceNode(node, right) node.Right = right.Left @@ -374,7 +371,7 @@ func (tree *Tree) rotateLeft(node *Node) { node.Parent = right } -func (tree *Tree) rotateRight(node *Node) { +func (tree *Tree[K, V]) rotateRight(node *Node[K, V]) { left := node.Left tree.replaceNode(node, left) node.Left = left.Right @@ -385,7 +382,7 @@ func (tree *Tree) rotateRight(node *Node) { node.Parent = left } -func (tree *Tree) replaceNode(old *Node, new *Node) { +func (tree *Tree[K, V]) replaceNode(old *Node[K, V], new *Node[K, V]) { if old.Parent == nil { tree.Root = new } else { @@ -400,7 +397,7 @@ func (tree *Tree) replaceNode(old *Node, new *Node) { } } -func (tree *Tree) insertCase1(node *Node) { +func (tree *Tree[K, V]) insertCase1(node *Node[K, V]) { if node.Parent == nil { node.color = black } else { @@ -408,14 +405,14 @@ func (tree *Tree) insertCase1(node *Node) { } } -func (tree *Tree) insertCase2(node *Node) { +func (tree *Tree[K, V]) insertCase2(node *Node[K, V]) { if nodeColor(node.Parent) == black { return } tree.insertCase3(node) } -func (tree *Tree) insertCase3(node *Node) { +func (tree *Tree[K, V]) insertCase3(node *Node[K, V]) { uncle := node.uncle() if nodeColor(uncle) == red { node.Parent.color = black @@ -427,7 +424,7 @@ func (tree *Tree) insertCase3(node *Node) { } } -func (tree *Tree) insertCase4(node *Node) { +func (tree *Tree[K, V]) insertCase4(node *Node[K, V]) { grandparent := node.grandparent() if node == node.Parent.Right && node.Parent == grandparent.Left { tree.rotateLeft(node.Parent) @@ -439,7 +436,7 @@ func (tree *Tree) insertCase4(node *Node) { tree.insertCase5(node) } -func (tree *Tree) insertCase5(node *Node) { +func (tree *Tree[K, V]) insertCase5(node *Node[K, V]) { node.Parent.color = black grandparent := node.grandparent() grandparent.color = red @@ -450,7 +447,7 @@ func (tree *Tree) insertCase5(node *Node) { } } -func (node *Node) maximumNode() *Node { +func (node *Node[K, V]) maximumNode() *Node[K, V] { if node == nil { return nil } @@ -460,14 +457,14 @@ func (node *Node) maximumNode() *Node { return node } -func (tree *Tree) deleteCase1(node *Node) { +func (tree *Tree[K, V]) deleteCase1(node *Node[K, V]) { if node.Parent == nil { return } tree.deleteCase2(node) } -func (tree *Tree) deleteCase2(node *Node) { +func (tree *Tree[K, V]) deleteCase2(node *Node[K, V]) { sibling := node.sibling() if nodeColor(sibling) == red { node.Parent.color = red @@ -481,7 +478,7 @@ func (tree *Tree) deleteCase2(node *Node) { tree.deleteCase3(node) } -func (tree *Tree) deleteCase3(node *Node) { +func (tree *Tree[K, V]) deleteCase3(node *Node[K, V]) { sibling := node.sibling() if nodeColor(node.Parent) == black && nodeColor(sibling) == black && @@ -494,7 +491,7 @@ func (tree *Tree) deleteCase3(node *Node) { } } -func (tree *Tree) deleteCase4(node *Node) { +func (tree *Tree[K, V]) deleteCase4(node *Node[K, V]) { sibling := node.sibling() if nodeColor(node.Parent) == red && nodeColor(sibling) == black && @@ -507,7 +504,7 @@ func (tree *Tree) deleteCase4(node *Node) { } } -func (tree *Tree) deleteCase5(node *Node) { +func (tree *Tree[K, V]) deleteCase5(node *Node[K, V]) { sibling := node.sibling() if node == node.Parent.Left && nodeColor(sibling) == black && @@ -527,7 +524,7 @@ func (tree *Tree) deleteCase5(node *Node) { tree.deleteCase6(node) } -func (tree *Tree) deleteCase6(node *Node) { +func (tree *Tree[K, V]) deleteCase6(node *Node[K, V]) { sibling := node.sibling() sibling.color = nodeColor(node.Parent) node.Parent.color = black @@ -540,7 +537,7 @@ func (tree *Tree) deleteCase6(node *Node) { } } -func nodeColor(node *Node) color { +func nodeColor[K comparable, V any](node *Node[K, V]) color { if node == nil { return black } diff --git a/trees/redblacktree/redblacktree_test.go b/trees/redblacktree/redblacktree_test.go index 4c0f519a..b4e76b1e 100644 --- a/trees/redblacktree/redblacktree_test.go +++ b/trees/redblacktree/redblacktree_test.go @@ -7,13 +7,13 @@ package redblacktree import ( "encoding/json" "fmt" - "github.com/emirpasic/gods/utils" + "slices" "strings" "testing" ) func TestRedBlackTreeGet(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() if actualValue := tree.Size(); actualValue != 0 { t.Errorf("Got %v expected %v", actualValue, 0) @@ -59,7 +59,7 @@ func TestRedBlackTreeGet(t *testing.T) { } func TestRedBlackTreePut(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -72,10 +72,10 @@ func TestRedBlackTreePut(t *testing.T) { if actualValue := tree.Size(); actualValue != 7 { t.Errorf("Got %v expected %v", actualValue, 7) } - if actualValue, expectedValue := fmt.Sprintf("%d%d%d%d%d%d%d", tree.Keys()...), "1234567"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{1, 2, 3, 4, 5, 6, 7}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s%s%s%s", tree.Values()...), "abcdefg"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{"a", "b", "c", "d", "e", "f", "g"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -87,12 +87,12 @@ func TestRedBlackTreePut(t *testing.T) { {5, "e", true}, {6, "f", true}, {7, "g", true}, - {8, nil, false}, + {8, "", false}, } for _, test := range tests1 { // retrievals - actualValue, actualFound := tree.Get(test[0]) + actualValue, actualFound := tree.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -100,7 +100,7 @@ func TestRedBlackTreePut(t *testing.T) { } func TestRedBlackTreeRemove(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -116,13 +116,10 @@ func TestRedBlackTreeRemove(t *testing.T) { tree.Remove(8) tree.Remove(5) - if actualValue, expectedValue := fmt.Sprintf("%d%d%d%d", tree.Keys()...), "1234"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{1, 2, 3, 4}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", tree.Values()...), "abcd"; actualValue != expectedValue { - t.Errorf("Got %v expected %v", actualValue, expectedValue) - } - if actualValue, expectedValue := fmt.Sprintf("%s%s%s%s", tree.Values()...), "abcd"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{"a", "b", "c", "d"}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if actualValue := tree.Size(); actualValue != 4 { @@ -134,14 +131,14 @@ func TestRedBlackTreeRemove(t *testing.T) { {2, "b", true}, {3, "c", true}, {4, "d", true}, - {5, nil, false}, - {6, nil, false}, - {7, nil, false}, - {8, nil, false}, + {5, "", false}, + {6, "", false}, + {7, "", false}, + {8, "", false}, } for _, test := range tests2 { - actualValue, actualFound := tree.Get(test[0]) + actualValue, actualFound := tree.Get(test[0].(int)) if actualValue != test[1] || actualFound != test[2] { t.Errorf("Got %v expected %v", actualValue, test[1]) } @@ -154,10 +151,10 @@ func TestRedBlackTreeRemove(t *testing.T) { tree.Remove(2) tree.Remove(2) - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Keys()), "[]"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Keys(), []int{}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Values()), "[]"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Values(), []string{}; !slices.Equal(actualValue, expectedValue) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } if empty, size := tree.Empty(), tree.Size(); empty != true || size != -0 { @@ -167,7 +164,7 @@ func TestRedBlackTreeRemove(t *testing.T) { } func TestRedBlackTreeLeftAndRight(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() if actualValue := tree.Left(); actualValue != nil { t.Errorf("Got %v expected %v", actualValue, nil) @@ -185,23 +182,23 @@ func TestRedBlackTreeLeftAndRight(t *testing.T) { tree.Put(1, "x") // overwrite tree.Put(2, "b") - if actualValue, expectedValue := fmt.Sprintf("%d", tree.Left().Key), "1"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Left().Key, 1; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Left().Value), "x"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Left().Value, "x"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%d", tree.Right().Key), "7"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Right().Key, 7; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue, expectedValue := fmt.Sprintf("%s", tree.Right().Value), "g"; actualValue != expectedValue { + if actualValue, expectedValue := tree.Right().Value, "g"; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } } func TestRedBlackTreeCeilingAndFloor(t *testing.T) { - tree := NewWith(utils.IntComparator) + tree := New[int, string]() if node, found := tree.Floor(0); node != nil || found { t.Errorf("Got %v expected %v", node, "") @@ -234,7 +231,7 @@ func TestRedBlackTreeCeilingAndFloor(t *testing.T) { } func TestRedBlackTreeIteratorNextOnEmpty(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.Next() { t.Errorf("Shouldn't iterate on empty tree") @@ -242,7 +239,7 @@ func TestRedBlackTreeIteratorNextOnEmpty(t *testing.T) { } func TestRedBlackTreeIteratorPrevOnEmpty(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.Prev() { t.Errorf("Shouldn't iterate on empty tree") @@ -250,7 +247,7 @@ func TestRedBlackTreeIteratorPrevOnEmpty(t *testing.T) { } func TestRedBlackTreeIterator1Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -281,7 +278,7 @@ func TestRedBlackTreeIterator1Next(t *testing.T) { } func TestRedBlackTreeIterator1Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(5, "e") tree.Put(6, "f") tree.Put(7, "g") @@ -314,7 +311,7 @@ func TestRedBlackTreeIterator1Prev(t *testing.T) { } func TestRedBlackTreeIterator2Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -333,7 +330,7 @@ func TestRedBlackTreeIterator2Next(t *testing.T) { } func TestRedBlackTreeIterator2Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -354,7 +351,7 @@ func TestRedBlackTreeIterator2Prev(t *testing.T) { } func TestRedBlackTreeIterator3Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(1, "a") it := tree.Iterator() count := 0 @@ -371,7 +368,7 @@ func TestRedBlackTreeIterator3Next(t *testing.T) { } func TestRedBlackTreeIterator3Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(1, "a") it := tree.Iterator() for it.Next() { @@ -390,7 +387,7 @@ func TestRedBlackTreeIterator3Prev(t *testing.T) { } func TestRedBlackTreeIterator4Next(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, int]() tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -426,7 +423,7 @@ func TestRedBlackTreeIterator4Next(t *testing.T) { } func TestRedBlackTreeIterator4Prev(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, int]() tree.Put(13, 5) tree.Put(8, 3) tree.Put(17, 7) @@ -464,7 +461,7 @@ func TestRedBlackTreeIterator4Prev(t *testing.T) { } func TestRedBlackTreeIteratorBegin(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -496,7 +493,7 @@ func TestRedBlackTreeIteratorBegin(t *testing.T) { } func TestRedBlackTreeIteratorEnd(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() if it.node != nil { @@ -523,7 +520,7 @@ func TestRedBlackTreeIteratorEnd(t *testing.T) { } func TestRedBlackTreeIteratorFirst(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -537,7 +534,7 @@ func TestRedBlackTreeIteratorFirst(t *testing.T) { } func TestRedBlackTreeIteratorLast(t *testing.T) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(3, "c") tree.Put(1, "a") tree.Put(2, "b") @@ -552,13 +549,13 @@ func TestRedBlackTreeIteratorLast(t *testing.T) { func TestRedBlackTreeIteratorNextTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // NextTo (empty) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() for it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") @@ -567,7 +564,7 @@ func TestRedBlackTreeIteratorNextTo(t *testing.T) { // NextTo (not found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -578,7 +575,7 @@ func TestRedBlackTreeIteratorNextTo(t *testing.T) { // NextTo (found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -587,13 +584,13 @@ func TestRedBlackTreeIteratorNextTo(t *testing.T) { if !it.NextTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Next() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 2 || value.(string) != "cc" { + if index, value := it.Key(), it.Value(); index != 2 || value != "cc" { t.Errorf("Got %v,%v expected %v,%v", index, value, 2, "cc") } if it.Next() { @@ -604,13 +601,13 @@ func TestRedBlackTreeIteratorNextTo(t *testing.T) { func TestRedBlackTreeIteratorPrevTo(t *testing.T) { // Sample seek function, i.e. string starting with "b" - seek := func(index interface{}, value interface{}) bool { - return strings.HasSuffix(value.(string), "b") + seek := func(index int, value string) bool { + return strings.HasSuffix(value, "b") } // PrevTo (empty) { - tree := NewWithIntComparator() + tree := New[int, string]() it := tree.Iterator() it.End() for it.PrevTo(seek) { @@ -620,7 +617,7 @@ func TestRedBlackTreeIteratorPrevTo(t *testing.T) { // PrevTo (not found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(0, "xx") tree.Put(1, "yy") it := tree.Iterator() @@ -632,7 +629,7 @@ func TestRedBlackTreeIteratorPrevTo(t *testing.T) { // PrevTo (found) { - tree := NewWithIntComparator() + tree := New[int, string]() tree.Put(2, "cc") tree.Put(0, "aa") tree.Put(1, "bb") @@ -641,13 +638,13 @@ func TestRedBlackTreeIteratorPrevTo(t *testing.T) { if !it.PrevTo(seek) { t.Errorf("Shouldn't iterate on empty tree") } - if index, value := it.Key(), it.Value(); index != 1 || value.(string) != "bb" { + if index, value := it.Key(), it.Value(); index != 1 || value != "bb" { t.Errorf("Got %v,%v expected %v,%v", index, value, 1, "bb") } if !it.Prev() { t.Errorf("Should go to first element") } - if index, value := it.Key(), it.Value(); index != 0 || value.(string) != "aa" { + if index, value := it.Key(), it.Value(); index != 0 || value != "aa" { t.Errorf("Got %v,%v expected %v,%v", index, value, 0, "aa") } if it.Prev() { @@ -657,7 +654,7 @@ func TestRedBlackTreeIteratorPrevTo(t *testing.T) { } func TestRedBlackTreeSerialization(t *testing.T) { - tree := NewWithStringComparator() + tree := New[string, string]() tree.Put("c", "3") tree.Put("b", "2") tree.Put("a", "1") @@ -667,11 +664,11 @@ func TestRedBlackTreeSerialization(t *testing.T) { if actualValue, expectedValue := tree.Size(), 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Keys(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { - t.Errorf("Got %v expected %v", actualValue, "[a,b,c]") + if actualValue, expectedValue := tree.Keys(), []string{"a", "b", "c"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if actualValue := tree.Values(); actualValue[0].(string) != "1" || actualValue[1].(string) != "2" || actualValue[2].(string) != "3" { - t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") + if actualValue, expectedValue := tree.Values(), []string{"1", "2", "3"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) } if err != nil { t.Errorf("Got error %v", err) @@ -691,21 +688,31 @@ func TestRedBlackTreeSerialization(t *testing.T) { t.Errorf("Got error %v", err) } - err = json.Unmarshal([]byte(`{"a":1,"b":2}`), &tree) + intTree := New[string, int]() + err = json.Unmarshal([]byte(`{"a":1,"b":2}`), intTree) if err != nil { t.Errorf("Got error %v", err) } + if actualValue, expectedValue := intTree.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Keys(), []string{"a", "b"}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := intTree.Values(), []int{1, 2}; !slices.Equal(actualValue, expectedValue) { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } } func TestRedBlackTreeString(t *testing.T) { - c := NewWithStringComparator() + c := New[string, int]() c.Put("a", 1) if !strings.HasPrefix(c.String(), "RedBlackTree") { t.Errorf("String should start with container name") } } -func benchmarkGet(b *testing.B, tree *Tree, size int) { +func benchmarkGet(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Get(n) @@ -713,7 +720,7 @@ func benchmarkGet(b *testing.B, tree *Tree, size int) { } } -func benchmarkPut(b *testing.B, tree *Tree, size int) { +func benchmarkPut(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Put(n, struct{}{}) @@ -721,7 +728,7 @@ func benchmarkPut(b *testing.B, tree *Tree, size int) { } } -func benchmarkRemove(b *testing.B, tree *Tree, size int) { +func benchmarkRemove(b *testing.B, tree *Tree[int, struct{}], size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { tree.Remove(n) @@ -732,7 +739,7 @@ func benchmarkRemove(b *testing.B, tree *Tree, size int) { func BenchmarkRedBlackTreeGet100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -743,7 +750,7 @@ func BenchmarkRedBlackTreeGet100(b *testing.B) { func BenchmarkRedBlackTreeGet1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -754,7 +761,7 @@ func BenchmarkRedBlackTreeGet1000(b *testing.B) { func BenchmarkRedBlackTreeGet10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -765,7 +772,7 @@ func BenchmarkRedBlackTreeGet10000(b *testing.B) { func BenchmarkRedBlackTreeGet100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -776,7 +783,7 @@ func BenchmarkRedBlackTreeGet100000(b *testing.B) { func BenchmarkRedBlackTreePut100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() b.StartTimer() benchmarkPut(b, tree, size) } @@ -784,7 +791,7 @@ func BenchmarkRedBlackTreePut100(b *testing.B) { func BenchmarkRedBlackTreePut1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -795,7 +802,7 @@ func BenchmarkRedBlackTreePut1000(b *testing.B) { func BenchmarkRedBlackTreePut10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -806,7 +813,7 @@ func BenchmarkRedBlackTreePut10000(b *testing.B) { func BenchmarkRedBlackTreePut100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -817,7 +824,7 @@ func BenchmarkRedBlackTreePut100000(b *testing.B) { func BenchmarkRedBlackTreeRemove100(b *testing.B) { b.StopTimer() size := 100 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -828,7 +835,7 @@ func BenchmarkRedBlackTreeRemove100(b *testing.B) { func BenchmarkRedBlackTreeRemove1000(b *testing.B) { b.StopTimer() size := 1000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -839,7 +846,7 @@ func BenchmarkRedBlackTreeRemove1000(b *testing.B) { func BenchmarkRedBlackTreeRemove10000(b *testing.B) { b.StopTimer() size := 10000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } @@ -850,7 +857,7 @@ func BenchmarkRedBlackTreeRemove10000(b *testing.B) { func BenchmarkRedBlackTreeRemove100000(b *testing.B) { b.StopTimer() size := 100000 - tree := NewWithIntComparator() + tree := New[int, struct{}]() for n := 0; n < size; n++ { tree.Put(n, struct{}{}) } diff --git a/trees/redblacktree/serialization.go b/trees/redblacktree/serialization.go index 9f2a23c0..9311c897 100644 --- a/trees/redblacktree/serialization.go +++ b/trees/redblacktree/serialization.go @@ -6,27 +6,27 @@ package redblacktree import ( "encoding/json" - "github.com/emirpasic/gods/containers" - "github.com/emirpasic/gods/utils" + + "github.com/emirpasic/gods/v2/containers" ) // Assert Serialization implementation -var _ containers.JSONSerializer = (*Tree)(nil) -var _ containers.JSONDeserializer = (*Tree)(nil) +var _ containers.JSONSerializer = (*Tree[string, int])(nil) +var _ containers.JSONDeserializer = (*Tree[string, int])(nil) // ToJSON outputs the JSON representation of the tree. -func (tree *Tree) ToJSON() ([]byte, error) { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) ToJSON() ([]byte, error) { + elements := make(map[K]V) it := tree.Iterator() for it.Next() { - elements[utils.ToString(it.Key())] = it.Value() + elements[it.Key()] = it.Value() } return json.Marshal(&elements) } // FromJSON populates the tree from the input JSON representation. -func (tree *Tree) FromJSON(data []byte) error { - elements := make(map[string]interface{}) +func (tree *Tree[K, V]) FromJSON(data []byte) error { + elements := make(map[K]V) err := json.Unmarshal(data, &elements) if err == nil { tree.Clear() @@ -38,11 +38,11 @@ func (tree *Tree) FromJSON(data []byte) error { } // UnmarshalJSON @implements json.Unmarshaler -func (tree *Tree) UnmarshalJSON(bytes []byte) error { +func (tree *Tree[K, V]) UnmarshalJSON(bytes []byte) error { return tree.FromJSON(bytes) } // MarshalJSON @implements json.Marshaler -func (tree *Tree) MarshalJSON() ([]byte, error) { +func (tree *Tree[K, V]) MarshalJSON() ([]byte, error) { return tree.ToJSON() } diff --git a/trees/trees.go b/trees/trees.go index 8d1b868f..ad544007 100644 --- a/trees/trees.go +++ b/trees/trees.go @@ -9,11 +9,11 @@ // Reference: https://en.wikipedia.org/wiki/Tree_%28data_structure%29 package trees -import "github.com/emirpasic/gods/containers" +import "github.com/emirpasic/gods/v2/containers" // Tree interface that all trees implement -type Tree interface { - containers.Container +type Tree[V any] interface { + containers.Container[V] // Empty() bool // Size() int // Clear() diff --git a/utils/comparator.go b/utils/comparator.go index 6a9afbf3..37ef49d7 100644 --- a/utils/comparator.go +++ b/utils/comparator.go @@ -6,244 +6,14 @@ package utils import "time" -// Comparator will make type assertion (see IntComparator for example), -// which will panic if a or b are not of the asserted type. -// -// Should return a number: -// negative , if a < b -// zero , if a == b -// positive , if a > b -type Comparator func(a, b interface{}) int - -// StringComparator provides a fast comparison on strings -func StringComparator(a, b interface{}) int { - s1 := a.(string) - s2 := b.(string) - min := len(s2) - if len(s1) < len(s2) { - min = len(s1) - } - diff := 0 - for i := 0; i < min && diff == 0; i++ { - diff = int(s1[i]) - int(s2[i]) - } - if diff == 0 { - diff = len(s1) - len(s2) - } - if diff < 0 { - return -1 - } - if diff > 0 { - return 1 - } - return 0 -} - -// IntComparator provides a basic comparison on int -func IntComparator(a, b interface{}) int { - aAsserted := a.(int) - bAsserted := b.(int) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Int8Comparator provides a basic comparison on int8 -func Int8Comparator(a, b interface{}) int { - aAsserted := a.(int8) - bAsserted := b.(int8) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Int16Comparator provides a basic comparison on int16 -func Int16Comparator(a, b interface{}) int { - aAsserted := a.(int16) - bAsserted := b.(int16) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Int32Comparator provides a basic comparison on int32 -func Int32Comparator(a, b interface{}) int { - aAsserted := a.(int32) - bAsserted := b.(int32) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Int64Comparator provides a basic comparison on int64 -func Int64Comparator(a, b interface{}) int { - aAsserted := a.(int64) - bAsserted := b.(int64) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// UIntComparator provides a basic comparison on uint -func UIntComparator(a, b interface{}) int { - aAsserted := a.(uint) - bAsserted := b.(uint) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// UInt8Comparator provides a basic comparison on uint8 -func UInt8Comparator(a, b interface{}) int { - aAsserted := a.(uint8) - bAsserted := b.(uint8) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// UInt16Comparator provides a basic comparison on uint16 -func UInt16Comparator(a, b interface{}) int { - aAsserted := a.(uint16) - bAsserted := b.(uint16) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// UInt32Comparator provides a basic comparison on uint32 -func UInt32Comparator(a, b interface{}) int { - aAsserted := a.(uint32) - bAsserted := b.(uint32) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// UInt64Comparator provides a basic comparison on uint64 -func UInt64Comparator(a, b interface{}) int { - aAsserted := a.(uint64) - bAsserted := b.(uint64) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Float32Comparator provides a basic comparison on float32 -func Float32Comparator(a, b interface{}) int { - aAsserted := a.(float32) - bAsserted := b.(float32) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// Float64Comparator provides a basic comparison on float64 -func Float64Comparator(a, b interface{}) int { - aAsserted := a.(float64) - bAsserted := b.(float64) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// ByteComparator provides a basic comparison on byte -func ByteComparator(a, b interface{}) int { - aAsserted := a.(byte) - bAsserted := b.(byte) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} - -// RuneComparator provides a basic comparison on rune -func RuneComparator(a, b interface{}) int { - aAsserted := a.(rune) - bAsserted := b.(rune) - switch { - case aAsserted > bAsserted: - return 1 - case aAsserted < bAsserted: - return -1 - default: - return 0 - } -} +type Comparator[T any] func(x, y T) int // TimeComparator provides a basic comparison on time.Time -func TimeComparator(a, b interface{}) int { - aAsserted := a.(time.Time) - bAsserted := b.(time.Time) - +func TimeComparator(a, b time.Time) int { switch { - case aAsserted.After(bAsserted): + case a.After(b): return 1 - case aAsserted.Before(bAsserted): + case a.Before(b): return -1 default: return 0 diff --git a/utils/comparator_test.go b/utils/comparator_test.go index 356c5e26..b985f94c 100644 --- a/utils/comparator_test.go +++ b/utils/comparator_test.go @@ -9,51 +9,6 @@ import ( "time" ) -func TestIntComparator(t *testing.T) { - - // i1,i2,expected - tests := [][]interface{}{ - {1, 1, 0}, - {1, 2, -1}, - {2, 1, 1}, - {11, 22, -1}, - {0, 0, 0}, - {1, 0, 1}, - {0, 1, -1}, - } - - for _, test := range tests { - actual := IntComparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestStringComparator(t *testing.T) { - - // s1,s2,expected - tests := [][]interface{}{ - {"a", "a", 0}, - {"a", "b", -1}, - {"b", "a", 1}, - {"aa", "aab", -1}, - {"", "", 0}, - {"a", "", 1}, - {"", "a", -1}, - {"", "aaaaaaa", -1}, - } - - for _, test := range tests { - actual := StringComparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - func TestTimeComparator(t *testing.T) { now := time.Now() @@ -66,239 +21,7 @@ func TestTimeComparator(t *testing.T) { } for _, test := range tests { - actual := TimeComparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestCustomComparator(t *testing.T) { - - type Custom struct { - id int - name string - } - - byID := func(a, b interface{}) int { - c1 := a.(Custom) - c2 := b.(Custom) - switch { - case c1.id > c2.id: - return 1 - case c1.id < c2.id: - return -1 - default: - return 0 - } - } - - // o1,o2,expected - tests := [][]interface{}{ - {Custom{1, "a"}, Custom{1, "a"}, 0}, - {Custom{1, "a"}, Custom{2, "b"}, -1}, - {Custom{2, "b"}, Custom{1, "a"}, 1}, - {Custom{1, "a"}, Custom{1, "b"}, 0}, - } - - for _, test := range tests { - actual := byID(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestInt8ComparatorComparator(t *testing.T) { - tests := [][]interface{}{ - {int8(1), int8(1), 0}, - {int8(0), int8(1), -1}, - {int8(1), int8(0), 1}, - } - for _, test := range tests { - actual := Int8Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestInt16Comparator(t *testing.T) { - tests := [][]interface{}{ - {int16(1), int16(1), 0}, - {int16(0), int16(1), -1}, - {int16(1), int16(0), 1}, - } - for _, test := range tests { - actual := Int16Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestInt32Comparator(t *testing.T) { - tests := [][]interface{}{ - {int32(1), int32(1), 0}, - {int32(0), int32(1), -1}, - {int32(1), int32(0), 1}, - } - for _, test := range tests { - actual := Int32Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestInt64Comparator(t *testing.T) { - tests := [][]interface{}{ - {int64(1), int64(1), 0}, - {int64(0), int64(1), -1}, - {int64(1), int64(0), 1}, - } - for _, test := range tests { - actual := Int64Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestUIntComparator(t *testing.T) { - tests := [][]interface{}{ - {uint(1), uint(1), 0}, - {uint(0), uint(1), -1}, - {uint(1), uint(0), 1}, - } - for _, test := range tests { - actual := UIntComparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestUInt8Comparator(t *testing.T) { - tests := [][]interface{}{ - {uint8(1), uint8(1), 0}, - {uint8(0), uint8(1), -1}, - {uint8(1), uint8(0), 1}, - } - for _, test := range tests { - actual := UInt8Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestUInt16Comparator(t *testing.T) { - tests := [][]interface{}{ - {uint16(1), uint16(1), 0}, - {uint16(0), uint16(1), -1}, - {uint16(1), uint16(0), 1}, - } - for _, test := range tests { - actual := UInt16Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestUInt32Comparator(t *testing.T) { - tests := [][]interface{}{ - {uint32(1), uint32(1), 0}, - {uint32(0), uint32(1), -1}, - {uint32(1), uint32(0), 1}, - } - for _, test := range tests { - actual := UInt32Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestUInt64Comparator(t *testing.T) { - tests := [][]interface{}{ - {uint64(1), uint64(1), 0}, - {uint64(0), uint64(1), -1}, - {uint64(1), uint64(0), 1}, - } - for _, test := range tests { - actual := UInt64Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestFloat32Comparator(t *testing.T) { - tests := [][]interface{}{ - {float32(1.1), float32(1.1), 0}, - {float32(0.1), float32(1.1), -1}, - {float32(1.1), float32(0.1), 1}, - } - for _, test := range tests { - actual := Float32Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestFloat64Comparator(t *testing.T) { - tests := [][]interface{}{ - {float64(1.1), float64(1.1), 0}, - {float64(0.1), float64(1.1), -1}, - {float64(1.1), float64(0.1), 1}, - } - for _, test := range tests { - actual := Float64Comparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestByteComparator(t *testing.T) { - tests := [][]interface{}{ - {byte(1), byte(1), 0}, - {byte(0), byte(1), -1}, - {byte(1), byte(0), 1}, - } - for _, test := range tests { - actual := ByteComparator(test[0], test[1]) - expected := test[2] - if actual != expected { - t.Errorf("Got %v expected %v", actual, expected) - } - } -} - -func TestRuneComparator(t *testing.T) { - tests := [][]interface{}{ - {rune(1), rune(1), 0}, - {rune(0), rune(1), -1}, - {rune(1), rune(0), 1}, - } - for _, test := range tests { - actual := RuneComparator(test[0], test[1]) + actual := TimeComparator(test[0].(time.Time), test[1].(time.Time)) expected := test[2] if actual != expected { t.Errorf("Got %v expected %v", actual, expected) diff --git a/utils/sort.go b/utils/sort.go deleted file mode 100644 index 79ced1f5..00000000 --- a/utils/sort.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2015, Emir Pasic. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package utils - -import "sort" - -// Sort sorts values (in-place) with respect to the given comparator. -// -// Uses Go's sort (hybrid of quicksort for large and then insertion sort for smaller slices). -func Sort(values []interface{}, comparator Comparator) { - sort.Sort(sortable{values, comparator}) -} - -type sortable struct { - values []interface{} - comparator Comparator -} - -func (s sortable) Len() int { - return len(s.values) -} -func (s sortable) Swap(i, j int) { - s.values[i], s.values[j] = s.values[j], s.values[i] -} -func (s sortable) Less(i, j int) bool { - return s.comparator(s.values[i], s.values[j]) < 0 -} diff --git a/utils/sort_test.go b/utils/sort_test.go deleted file mode 100644 index 7831fc9b..00000000 --- a/utils/sort_test.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2015, Emir Pasic. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package utils - -import ( - "math/rand" - "testing" -) - -func TestSortInts(t *testing.T) { - ints := []interface{}{} - ints = append(ints, 4) - ints = append(ints, 1) - ints = append(ints, 2) - ints = append(ints, 3) - - Sort(ints, IntComparator) - - for i := 1; i < len(ints); i++ { - if ints[i-1].(int) > ints[i].(int) { - t.Errorf("Not sorted!") - } - } - -} - -func TestSortStrings(t *testing.T) { - - strings := []interface{}{} - strings = append(strings, "d") - strings = append(strings, "a") - strings = append(strings, "b") - strings = append(strings, "c") - - Sort(strings, StringComparator) - - for i := 1; i < len(strings); i++ { - if strings[i-1].(string) > strings[i].(string) { - t.Errorf("Not sorted!") - } - } -} - -func TestSortStructs(t *testing.T) { - type User struct { - id int - name string - } - - byID := func(a, b interface{}) int { - c1 := a.(User) - c2 := b.(User) - switch { - case c1.id > c2.id: - return 1 - case c1.id < c2.id: - return -1 - default: - return 0 - } - } - - // o1,o2,expected - users := []interface{}{ - User{4, "d"}, - User{1, "a"}, - User{3, "c"}, - User{2, "b"}, - } - - Sort(users, byID) - - for i := 1; i < len(users); i++ { - if users[i-1].(User).id > users[i].(User).id { - t.Errorf("Not sorted!") - } - } -} - -func TestSortRandom(t *testing.T) { - ints := []interface{}{} - for i := 0; i < 10000; i++ { - ints = append(ints, rand.Int()) - } - Sort(ints, IntComparator) - for i := 1; i < len(ints); i++ { - if ints[i-1].(int) > ints[i].(int) { - t.Errorf("Not sorted!") - } - } -} - -func BenchmarkGoSortRandom(b *testing.B) { - b.StopTimer() - ints := []interface{}{} - for i := 0; i < 100000; i++ { - ints = append(ints, rand.Int()) - } - b.StartTimer() - Sort(ints, IntComparator) - b.StopTimer() -}