Skip to content

Commit

Permalink
Added more functions to SparseVector
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 21, 2024
1 parent 309b771 commit 13c9ba8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
15 changes: 15 additions & 0 deletions sparsevec.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func NewSparseVectorFromMap(elements map[int32]float32, dim int32) SparseVector
return SparseVector{dim: dim, indices: indices, values: values}
}

// Dimensions returns the number of dimensions.
func (v SparseVector) Dimensions() int32 {
return v.dim
}

// Indices returns the non-zero indices.
func (v SparseVector) Indices() []int32 {
return v.indices
}

// Values returns the non-zero values.
func (v SparseVector) Values() []float32 {
return v.values
}

// Slice returns a slice of float32.
func (v SparseVector) Slice() []float32 {
vec := make([]float32, v.dim)
Expand Down
28 changes: 28 additions & 0 deletions sparsevec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,41 @@ import (
"github.com/pgvector/pgvector-go"
)

func TestNewSparseVector(t *testing.T) {
vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0})
if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) {
t.Error()
}
}

func TestNewSparseVectorFromMap(t *testing.T) {
vec := pgvector.NewSparseVectorFromMap(map[int32]float32{0: 1, 2: 2, 4: 3}, 6)
if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) {
t.Error()
}
}

func TestDimensions(t *testing.T) {
vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0})
if vec.Dimensions() != 6 {
t.Error()
}
}

func TestIndices(t *testing.T) {
vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0})
if !reflect.DeepEqual(vec.Indices(), []int32{0, 2, 4}) {
t.Error()
}
}

func TestValues(t *testing.T) {
vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0})
if !reflect.DeepEqual(vec.Values(), []float32{1, 2, 3}) {
t.Error()
}
}

func TestSparseVectorSlice(t *testing.T) {
vec := pgvector.NewSparseVector([]float32{1, 0, 2, 0, 3, 0})
if !reflect.DeepEqual(vec.Slice(), []float32{1, 0, 2, 0, 3, 0}) {
Expand Down

0 comments on commit 13c9ba8

Please sign in to comment.