Skip to content

Commit

Permalink
more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Oct 7, 2023
1 parent 52309c9 commit 9f9e7b0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 19 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ python-terrier>=0.9.1
torch
numpy
npids
sentence_transformers
16 changes: 16 additions & 0 deletions tests/test_flexindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def test_index_typical(self):
self.assertEqual(a['docno'], b['docno'])
self.assertTrue((a['doc_vec'] == b['doc_vec']).all())

# TODO: tests for:
# - corpus_graph
# - faiss_flat_retriever
# - faiss_hnsw_retriever
# - faiss_hnsw_graph
# - faiss_ivf_retriever
# - pre_ladr
# - ada_ladr
# - np_retriever
# - np_vec_loader
# - np_scorer
# - scann_retriever
# - torch_vecs
# - torch_scorer
# - torch_retriever

def setUp(self):
import pyterrier as pt
if not pt.started():
Expand Down
57 changes: 38 additions & 19 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,44 @@

class TestModels(unittest.TestCase):

def _base_test(self, model):
def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test_scorer=True, test_indexer=True, test_retriever=True):
from pyterrier_dr import FlexIndex
destdir = tempfile.mkdtemp()
self.test_dirs.append(destdir)
index = FlexIndex(destdir)
index = FlexIndex(destdir+'/index')

dataset = pt.get_dataset('irds:vaswani')

docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))

with self.subTest('query_encoder'):
query_res = model(dataset.get_topics())

with self.subTest('doc_encoder'):
doc_res = model(pd.DataFrame(docs))

with self.subTest('scorer'):
score_res = model(pd.DataFrame([
{'qid': '0', 'query': 'test query', 'docno': '0', 'text': 'test documemnt'},
]))

with self.subTest('indexer'):
pipeline = model >> index
pipeline.index(docs)

with self.subTest('retriever'):
retr_res = pipeline(dataset.get_topics())
if test_query_encoder:
with self.subTest('query_encoder'):
query_res = model(dataset.get_topics())
# TODO: what to assert about query_res?

if test_doc_encoder:
with self.subTest('doc_encoder'):
doc_res = model(pd.DataFrame(docs))
# TODO: what to assert about doc_res?

if test_scorer:
with self.subTest('scorer'):
# TODO: more comprehensive test case
score_res = model(pd.DataFrame([
{'qid': '0', 'query': 'test query', 'docno': '0', 'text': 'test documemnt'},
]))
# TODO: what to assert about score_res?

if test_indexer:
with self.subTest('indexer'):
pipeline = model >> index
pipeline.index(docs)
# TODO: what to assert?

if test_retriever:
with self.subTest('retriever'):
retr_res = pipeline(dataset.get_topics())
# TODO: what to assert about retr_res?

def test_tct(self):
from pyterrier_dr import TctColBert
Expand All @@ -47,6 +58,14 @@ def test_tasb(self):
from pyterrier_dr import TasB
self._base_test(TasB.dot())

def test_retromae(self):
from pyterrier_dr import RetroMAE
self._base_test(RetroMAE.msmarco_finetune())

def test_query2query(self):
from pyterrier_dr import Query2Query
self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

def setUp(self):
import pyterrier as pt
if not pt.started():
Expand Down

0 comments on commit 9f9e7b0

Please sign in to comment.