diff --git a/requirements.txt b/requirements.txt index 55fba85..2f4ce13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ python-terrier>=0.9.1 torch numpy npids +sentence_transformers diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py index 1b41422..ea79a8b 100644 --- a/tests/test_flexindex.py +++ b/tests/test_flexindex.py @@ -32,6 +32,22 @@ def test_index_typical(self): self.assertEqual(a['docno'], b['docno']) self.assertTrue((a['doc_vec'] == b['doc_vec']).all()) + # TODO: tests for: + # - corpus_graph + # - faiss_flat_retriever + # - faiss_hnsw_retriever + # - faiss_hnsw_graph + # - faiss_ivf_retriever + # - pre_ladr + # - ada_ladr + # - np_retriever + # - np_vec_loader + # - np_scorer + # - scann_retriever + # - torch_vecs + # - torch_scorer + # - torch_retriever + def setUp(self): import pyterrier as pt if not pt.started(): diff --git a/tests/test_models.py b/tests/test_models.py index 13107e7..c5c51e9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,33 +7,44 @@ class TestModels(unittest.TestCase): - def _base_test(self, model): + def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test_scorer=True, test_indexer=True, test_retriever=True): from pyterrier_dr import FlexIndex destdir = tempfile.mkdtemp() self.test_dirs.append(destdir) - index = FlexIndex(destdir) + index = FlexIndex(destdir+'/index') dataset = pt.get_dataset('irds:vaswani') docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200)) - with self.subTest('query_encoder'): - query_res = model(dataset.get_topics()) - - with self.subTest('doc_encoder'): - doc_res = model(pd.DataFrame(docs)) - - with self.subTest('scorer'): - score_res = model(pd.DataFrame([ - {'qid': '0', 'query': 'test query', 'docno': '0', 'text': 'test documemnt'}, - ])) - - with self.subTest('indexer'): - pipeline = model >> index - pipeline.index(docs) - - with self.subTest('retriever'): - retr_res = pipeline(dataset.get_topics()) + if test_query_encoder: + with self.subTest('query_encoder'): + query_res = model(dataset.get_topics()) + # TODO: what to assert about query_res? + + if test_doc_encoder: + with self.subTest('doc_encoder'): + doc_res = model(pd.DataFrame(docs)) + # TODO: what to assert about doc_res? + + if test_scorer: + with self.subTest('scorer'): + # TODO: more comprehensive test case + score_res = model(pd.DataFrame([ + {'qid': '0', 'query': 'test query', 'docno': '0', 'text': 'test documemnt'}, + ])) + # TODO: what to assert about score_res? + + if test_indexer: + with self.subTest('indexer'): + pipeline = model >> index + pipeline.index(docs) + # TODO: what to assert? + + if test_retriever: + with self.subTest('retriever'): + retr_res = pipeline(dataset.get_topics()) + # TODO: what to assert about retr_res? def test_tct(self): from pyterrier_dr import TctColBert @@ -47,6 +58,14 @@ def test_tasb(self): from pyterrier_dr import TasB self._base_test(TasB.dot()) + def test_retromae(self): + from pyterrier_dr import RetroMAE + self._base_test(RetroMAE.msmarco_finetune()) + + def test_query2query(self): + from pyterrier_dr import Query2Query + self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False) + def setUp(self): import pyterrier as pt if not pt.started():