Skip to content

Commit

Permalink
Merge pull request #12 from Eigen-DB/handle_overwriting_of_vectors
Browse files Browse the repository at this point in the history
Control how vectors are being overwritten
  • Loading branch information
Ryan-Awad authored Oct 17, 2024
2 parents 661b05f + 7b618ef commit bc3e290
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
26 changes: 26 additions & 0 deletions hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func (i *Index) Free() {

/*
Adds a vector to the HNSW index.
If the a vector with the same label already exists, the function returns an error
- vector: the vector to add to the index
Expand All @@ -211,13 +212,38 @@ func (i *Index) InsertVector(vector []float32, label uint64) error {
return fmt.Errorf("the vector you are trying to insert is %d-dimensional whereas your index is %d-dimensional", len(vector), i.dimensions)
}

_, err := i.GetVector(label)
if err == nil {
return fmt.Errorf("a vector with label %d already exists in the index", label)
}

if i.normalize {
Normalize(vector)
}
C.insertVector(i.index, (*C.float)(unsafe.Pointer(&vector[0])), C.ulong(label))
return getLastError()
}

/*
Replaces an existing vector in the HNSW index.
- label: the vector's label
- newVector: the new vector used to replace the old vector
Returns an error if one occured.
*/
func (i *Index) ReplaceVector(label uint64, newVector []float32) error {
if len(newVector) != i.dimensions {
return fmt.Errorf("the vector you are trying to insert is %d-dimensional whereas your index is %d-dimensional", len(newVector), i.dimensions)
}
if i.normalize {
Normalize(newVector)
}
C.insertVector(i.index, (*C.float)(unsafe.Pointer(&newVector[0])), C.ulong(label))
return getLastError()
}

/*
Returns a vector's components using its label
Expand Down
5 changes: 3 additions & 2 deletions hnsw_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ HNSW loadHNSW(char *location, int dim, char spaceType, unsigned long maxElements
} else { // default: L2
vectorSpace = new hnswlib::L2Space(dim);
}
return new hnswlib::HierarchicalNSW<float>(vectorSpace, std::string(location), false, maxElements); // load the index from the specified location
return new hnswlib::HierarchicalNSW<float>(vectorSpace, std::string(location), false, maxElements, true); // load the index from the specified location
} catch (const std::runtime_error e) {
lastErrorMsg = std::string(e.what());
return nullptr;
Expand Down Expand Up @@ -112,7 +112,8 @@ void freeHNSW(HNSW hnswIndex) {
}

/**
* Adds a vector to the HNSW index. If a vector with the specified label already exists, it will be overwritten.
* Adds a vector to the HNSW index.
* NOTE: If a vector with the specified label already exists, IT WILL BE OVERWRITTEN.
*
* @param hnswIndex: HNSW index to add the point to
* @param vector: the vector to add to the index
Expand Down
52 changes: 50 additions & 2 deletions tests/hnswgo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,56 @@ func TestInsertVectorOverwrite(t *testing.T) {
t.Fatalf("An error occured when inserting a vector: %s", err.Error())
}

if err := index.InsertVector([]float32{4.2, 6.2}, 1); err != nil {
t.Fatalf("An error occured when overwritting vector: %s", err.Error())
if err := index.InsertVector([]float32{4.2, 6.2}, 1); err == nil {
t.Fatalf("No error occured when attempting to overwrite a vector")
} else if err.Error() != "a vector with label 1 already exists in the index" {
t.Fatalf("Got the wrong error when trying to overwrite a vector: %s", err.Error())
}
}

func TestReplaceVectorSuccess(t *testing.T) {
index, err := setup()
if err != nil {
t.Fatal(err.Error())
}
defer index.Free()

if err := index.InsertVector([]float32{1.2, -4.2}, 1); err != nil {
t.Fatalf("An error occured when inserting a vector: %s", err.Error())
}

if err := index.ReplaceVector(1, []float32{4.2, 6.2}); err != nil {
t.Fatalf("An error occured when trying to replace a vector: %s", err.Error())
}
}

func TestReplaceVectorFirstInsert(t *testing.T) {
index, err := setup()
if err != nil {
t.Fatal(err.Error())
}
defer index.Free()

if err := index.ReplaceVector(1, []float32{4.2, 6.2}); err != nil {
t.Fatalf("An error occured when trying to replace a non-existant vector: %s", err.Error())
}
}

func TestReplaceVectorInvalidDims(t *testing.T) {
index, err := setup()
if err != nil {
t.Fatal(err.Error())
}
defer index.Free()

if err := index.InsertVector([]float32{1.2, -4.2}, 1); err != nil {
t.Fatalf("An error occured when inserting a vector: %s", err.Error())
}

if err := index.ReplaceVector(1, []float32{4.2, 6.2, 3.3}); err == nil {
t.Fatalf("No error occured when trying to replace a vector with invalid dimensions")
} else if err.Error() != "the vector you are trying to insert is 3-dimensional whereas your index is 2-dimensional" {
t.Fatalf("Got the wrong error when trying to overwrite a vector with invalid dimensions: %s", err.Error())
}
}

Expand Down

0 comments on commit bc3e290

Please sign in to comment.