Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan-Awad committed Sep 15, 2024
1 parent fc7ace4 commit 920abd6
Show file tree
Hide file tree
Showing 19 changed files with 203 additions and 18,719 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
run: go mod download

- name: Run tests
run: go test ./... -count=1
run: go test ./... -count=1 -v
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ require github.com/stretchr/testify v1.9.0
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
61 changes: 53 additions & 8 deletions hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package hnswgo
*/
import "C"
import (
"errors"
"fmt"
"math"
"unsafe"
Expand All @@ -15,13 +16,27 @@ import (
type Index struct {
index C.HNSW
dimensions int
size uint32
normalize bool
spaceType string
}

// Returns the last error message without clearing it. Returns nil if there is no error message.
func getLastErrorMsg() string {
return C.GoString(C.getLastErrorMsg())
func peekLastError() error {
	// A nil pointer from the wrapper means no error message is currently stored.
	if msg := C.peekLastErrorMsg(); msg != nil {
		// C.GoString copies the C string into Go memory.
		// NOTE(review): the wrapper hands back a strdup'd buffer that is not
		// freed here — confirm whether a C.free call is needed.
		return errors.New(C.GoString(msg))
	}
	return nil
}

// getLastError returns the last error message recorded by the C wrapper as a
// Go error and clears it on the C side. It returns nil if no error message is
// stored.
//
// NOTE(review): the wrapper keeps the message in a C++ thread_local, while Go
// may run successive cgo calls on different OS threads — confirm the message
// read here always belongs to the preceding call.
func getLastError() error {
	if msg := C.getLastErrorMsg(); msg != nil {
		// C.GoString copies the C string into Go memory.
		// NOTE(review): the wrapper hands back a strdup'd buffer that is not
		// freed here — confirm whether a C.free call is needed.
		return errors.New(C.GoString(msg))
	}
	return nil
}

/*
Expand Down Expand Up @@ -60,9 +75,23 @@ Returns a reference to an instance of an HNSW index.
Returns an instance of an HNSW index, or an error if there was a problem initializing the index.
*/
func New(dim int, m int, efConstruction int, randSeed int, maxElements uint32, spaceType string) (*Index, error) {
if dim < 1 {
return nil, errors.New("dimension must be >= 1")
}
if maxElements < 1 {
return nil, errors.New("max elements must be >= 1")
}
if m < 2 {
return nil, errors.New("m must be >= 2")
}
if efConstruction < 0 {
return nil, errors.New("efConstruction must be >= 0")
}

index := new(Index)
index.dimensions = dim
index.spaceType = spaceType
index.size = maxElements

if spaceType == "ip" {
index.index = C.initHNSW(C.int(dim), C.ulong(maxElements), C.int(m), C.int(efConstruction), C.int(randSeed), C.char('i'))
Expand All @@ -74,10 +103,10 @@ func New(dim int, m int, efConstruction int, randSeed int, maxElements uint32, s
}

if index.index == nil {
return nil, fmt.Errorf("failed to initialize HNSW index: %s", getLastErrorMsg())
return nil, getLastError()
}

return index, nil
return index, getLastError()
}

/*
Expand All @@ -94,11 +123,16 @@ Adds a vector to the HNSW index.
- label: the vector's label
*/
func (i *Index) InsertVector(vector []float32, label uint32) {
func (i *Index) InsertVector(vector []float32, label uint32) error {
	// Reject vectors whose length does not match the index before touching C.
	if got, want := len(vector), i.dimensions; got != want {
		return fmt.Errorf("the vector you are trying to insert is %d-dimensional whereas your index is %d-dimensional", got, want)
	}

	// Normalize is defined elsewhere in the package; presumably it rescales
	// the slice in place — verify before relying on the caller's data staying
	// untouched.
	if i.normalize {
		Normalize(vector)
	}

	// Pass the address of the first element so C reads the slice's backing
	// array directly. The length check above guarantees the slice is
	// non-empty as long as the index has at least one dimension.
	data := (*C.float)(unsafe.Pointer(&vector[0]))
	C.insertVector(i.index, data, C.ulong(label))

	// Surface (and clear) any error the wrapper recorded during the insert.
	return getLastError()
}

/*
Expand All @@ -111,6 +145,13 @@ Performs similarity search on the HNSW index.
Returns the labels and distances of each of the nearest neighbors. Note: the size of both arrays can be < k if k > num of vectors in the index
*/
func (i *Index) SearchKNN(vector []float32, k int) ([]uint32, []float32, error) {
if len(vector) != i.dimensions {
return nil, nil, fmt.Errorf("the query vector is %d-dimensional whereas your index is %d-dimensional", len(vector), i.dimensions)
}
if k < 1 || uint32(k) > i.size {
return nil, nil, fmt.Errorf("1 <= k <= index max size")
}

if i.normalize {
Normalize(vector)
}
Expand All @@ -121,7 +162,7 @@ func (i *Index) SearchKNN(vector []float32, k int) ([]uint32, []float32, error)
numResult := int(C.searchKNN(i.index, (*C.float)(unsafe.Pointer(&vector[0])), C.int(k), &Clabel[0], &Cdist[0])) // perform the search

if numResult < 0 {
return nil, nil, fmt.Errorf("an error occured with the HNSW algorithm: %s", getLastErrorMsg())
return nil, nil, fmt.Errorf("an error occured with the HNSW algorithm: %s", getLastError())
}

labels := make([]uint32, k)
Expand All @@ -131,16 +172,20 @@ func (i *Index) SearchKNN(vector []float32, k int) ([]uint32, []float32, error)
dists[i] = float32(Cdist[i])
}

return labels[:numResult], dists[:numResult], nil
return labels[:numResult], dists[:numResult], getLastError()
}

/*
Sets the efConstruction parameter in the HNSW index.
- efConstruction: the new efConstruction parameter
*/
func (i *Index) SetEfConstruction(efConstruction int) {
func (i *Index) SetEfConstruction(efConstruction int) error {
	if efConstruction < 0 {
		return errors.New("efConstruction must be >= 0")
	}
	// Go's int may be wider than C's int (typically 32 bits). Converting an
	// out-of-range value with C.int(...) would silently truncate it, so
	// reject it explicitly instead of handing the wrapper a wrapped value.
	if efConstruction > math.MaxInt32 {
		return fmt.Errorf("efConstruction must be <= %d", math.MaxInt32)
	}

	// NOTE(review): this forwards to the wrapper's setEf; confirm that setEf
	// really updates the construction-time parameter and not the query-time
	// ef value.
	C.setEf(i.index, C.int(efConstruction))

	// Surface (and clear) any error the wrapper recorded during the call.
	return getLastError()
}

//func Load(location string, dim int, spaceType string) *HNSW {
Expand Down
43 changes: 28 additions & 15 deletions hnsw_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "hnswlib/hnswlib.h"
#include "lib/hnswlib/hnswlib.h"
#include "hnsw_wrapper.h"

static thread_local std::string lastErrorMsg; // stores the last error message
Expand All @@ -8,10 +8,21 @@ static thread_local std::string lastErrorMsg; // stores the last error message
*
* @return last error message or a null pointer
*/
char* getLastErrorMsg() {
char* peekLastErrorMsg() {
    // No stored message: signal "no error" with a null pointer.
    if (lastErrorMsg.empty()) {
        return nullptr;
    }
    // Return a heap copy made with strdup so the pointer stays valid even if
    // lastErrorMsg is later modified. Nothing here frees the copy, so the
    // caller is responsible for releasing it.
    return strdup(lastErrorMsg.c_str());
}

/**
 * Fetches the last error message and then clears the stored message.
 *
 * The message is obtained through peekLastErrorMsg(), so the returned pointer
 * is whatever that function produces: a copy of the message, or a null
 * pointer when no message is stored.
 *
 * @return last error message or a null pointer
 */
char* getLastErrorMsg() {
    // Take the snapshot first; clearing beforehand would always yield nullptr.
    char* const message = peekLastErrorMsg();
    lastErrorMsg.clear();
    return message;
}

/**
* Instantiates and returns an HNSW index.
*
Expand All @@ -20,13 +31,11 @@ char* getLastErrorMsg() {
* @param m: `m` parameter in the HNSW algorithm
* @param efConstruction: `efConstruction` parameter in the HNSW algorithm
* @param randSeed: random seed
* @param spaceType: similarity metric to use in the index
* @param spaceType: similarity metric to use in the index ("l" = L2, "i" = IP, "c" = cosine). (default: "l")
*
* @return instance of a HNSW index
*/
HNSW initHNSW(int dim, unsigned long int maxElements, int m, int efConstruction, int randSeed, char spaceType) {
// add checks on all arguments with specific exceptions thrown

try {
hnswlib::SpaceInterface<float> *vectorSpace;
if (spaceType == 'i') { // inner product
Expand Down Expand Up @@ -64,6 +73,7 @@ void freeHNSW(HNSW hnswIndex) {
* @param label: the vector's label
*/
void insertVector(HNSW hnswIndex, float *vector, unsigned long int label) {
// ** I believe a vector is overwritten when you insert another vector with the same label **
try {
((hnswlib::HierarchicalNSW<float>*) hnswIndex)->addPoint(vector, label);
} catch(const std::runtime_error e) {
Expand All @@ -88,20 +98,23 @@ int searchKNN(HNSW hnswIndex, float *vector, int k, unsigned long int *labels, f
std::priority_queue<std::pair<float, hnswlib::labeltype>> searchResults;
try {
searchResults = ((hnswlib::HierarchicalNSW<float>*) hnswIndex)->searchKnn(vector, k);

int n = searchResults.size();
std::pair<float, hnswlib::labeltype> pair;
for (int i = n - 1; i >= 0; i--) {
pair = searchResults.top();
distances[i] = pair.first;
labels[i] = pair.second;
searchResults.pop();
}
return n;
} catch (const std::runtime_error e) {
lastErrorMsg = std::string(e.what());
return -1;
} catch (const std::exception e) {
lastErrorMsg = std::string(e.what());
return -1;
}

int n = searchResults.size();
std::pair<float, hnswlib::labeltype> pair;
for (int i = n - 1; i >= 0; i--) {
pair = searchResults.top();
distances[i] = pair.first;
labels[i] = pair.second;
searchResults.pop();
}
return n;
}

/**
Expand Down
1 change: 1 addition & 0 deletions hnsw_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
extern "C" {
#endif
typedef void* HNSW;
char* peekLastErrorMsg();
char* getLastErrorMsg();
HNSW initHNSW(int dim, unsigned long int maxElements, int m, int efConstruction, int randSeed, char simMetric);
void freeHNSW(HNSW hnswIndex);
Expand Down
Loading

0 comments on commit 920abd6

Please sign in to comment.