Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Oct 14, 2023
1 parent cf5b125 commit 439615e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ jobs:
conda install pip
pip install --upgrade --upgrade-strategy eager git+https://github.com/terrier-org/pyterrier.git#egg=python-terrier
pip install --upgrade --upgrade-strategy eager -r requirements.txt
pip install --upgrade --upgrade-strategy eager -r requirements-dev.txt
#install this software
pip install --timeout=120 .
pip install --upgrade --upgrade-strategy eager pytest
- name: All unit tests
env:
Expand Down
4 changes: 2 additions & 2 deletions pyterrier_dr/flex/faiss_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def _faiss_hnsw_retriever(self, neighbours=32, ef_construction=40, ef_search=16,


def _faiss_hnsw_graph(self, neighbours=32, ef_construction=40):
key = ('faiss_hnsw', neighbours, ef_construction)
key = ('faiss_hnsw', neighbours//2, ef_construction)
graph_name = f'hnsw_n-{neighbours}_ef-{ef_construction}.graph'
if key not in self._cache:
if not (self.index_path/graph_name/'pt_meta.json').exists():
retr = self.faiss_hnsw_retriever(neighbours=neighbours, ef_construction=ef_construction)
retr = self.faiss_hnsw_retriever(neighbours=neighbours//2, ef_construction=ef_construction)
_build_hnsw_graph(retr.faiss_index.hnsw, self.index_path/graph_name)
from pyterrier_adaptive import CorpusGraph
self._cache[key] = CorpusGraph.load(self.index_path/graph_name)
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pytest
git+https://github.com/terrierteam/pyterrier_adaptive
41 changes: 35 additions & 6 deletions tests/test_flexindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

class TestFlexIndex(unittest.TestCase):

def _generate_data(self, count=1000, dim=100):
return [
{'docno': str(i), 'doc_vec': np.random.rand(dim).astype(np.float32)}
for i in range(count)
]

def test_index_typical(self):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
Expand All @@ -15,10 +21,7 @@ def test_index_typical(self):

self.assertFalse(index.built())

dataset = [
{'docno': str(i), 'doc_vec': np.random.rand(100).astype(np.float32)}
for i in range(1000)
]
dataset = self._generate_data()

index.index(dataset)

Expand All @@ -32,11 +35,37 @@ def test_index_typical(self):
self.assertEqual(a['docno'], b['docno'])
self.assertTrue((a['doc_vec'] == b['doc_vec']).all())

def test_corpus_graph(self):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')

self.assertFalse(index.built())

dataset = self._generate_data()

index.index(dataset)
graph = index.corpus_graph(16)
self.assertEqual(graph.neighbours(4).shape, (16,))

def test_faiss_hnsw_graph(self):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir+'/index')

self.assertFalse(index.built())

dataset = self._generate_data()

index.index(dataset)
graph = index.faiss_hnsw_graph(16)
self.assertEqual(graph.neighbours(4).shape, (16,))

# TODO: tests for:
# - corpus_graph
# - faiss_flat_retriever
# - faiss_hnsw_retriever
# - faiss_hnsw_graph
# - faiss_ivf_retriever
# - pre_ladr
# - ada_ladr
Expand Down

0 comments on commit 439615e

Please sign in to comment.