Skip to content

Commit

Permalink
update tests to use FlexIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Oct 7, 2023
1 parent cd94dfe commit dc72cef
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 54 deletions.
Empty file added tests/__init__.py
Empty file.
70 changes: 16 additions & 54 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,24 @@

class TestIndexingTCT(unittest.TestCase):

def _indexing_100doc(self, model : pt.Transformer, indexer_clz, dim=None):
import pyterrier_dr

# a memoryindex doesnt need a directory
if indexer_clz == pyterrier_dr.MemIndex:
index = indexer_clz()
else:
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = indexer_clz(destdir, overwrite=True)
def test_indexing_tct(self):
from pyterrier_dr import FlexIndex, TctColBert
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir, overwrite=True)

model = TctColBert()

# create an indexing pipelne
idx_pipeline = model >> index

iter = pt.get_dataset("vaswani").get_corpus_iter()
# only index 200 docs
# only index 200 docs
idx_pipeline.index([ next(iter) for i in range(200) ])

retr_pipeline = model >> index
df1 = retr_pipeline.search("analogue computer")
self.assertTrue(len(df1) > 0)


def test_indexing_tct_numpy(self):
import pyterrier_dr
self._indexing_100doc(
pyterrier_dr.TctColBert(),
pyterrier_dr.NumpyIndex
)

def test_indexing_tct_torch(self):
import torch
if not torch.cuda.is_available():
self.skipTest("no cuda available")
import pyterrier_dr
self._indexing_100doc(
pyterrier_dr.TctColBert(),
pyterrier_dr.TorchIndex
)

def test_indexing_tct_faisshnsw(self):
import pyterrier_dr
self._indexing_100doc(
pyterrier_dr.TctColBert(),
pyterrier_dr.FaissHnsw
)

def test_indexing_tct_faissflat(self):
import pyterrier_dr
self._indexing_100doc(
pyterrier_dr.TctColBert(),
pyterrier_dr.FaissFlat
)

def test_indexing_tct_mem(self):
import pyterrier_dr
self._indexing_100doc(
pyterrier_dr.TctColBert(),
pyterrier_dr.MemIndex
)

def setUp(self):
import pyterrier as pt
Expand All @@ -75,8 +33,12 @@ def setUp(self):

def tearDown(self):
import shutil
try:
for d in self.test_dirs:
for d in self.test_dirs:
try:
shutil.rmtree(d)
except:
pass
except:
pass


if __name__ == '__main__':
unittest.main()

0 comments on commit dc72cef

Please sign in to comment.