Skip to content

Commit

Permalink
add tests for bgem3
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaschari committed Nov 18, 2024
1 parent e0e479b commit 0ebcde5
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,46 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test
self.assertTrue('docno' in retr_res.columns)
self.assertTrue('score' in retr_res.columns)
self.assertTrue('rank' in retr_res.columns)

def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
dataset = pt.get_dataset('irds:vaswani')

docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
docs_df = pd.DataFrame(docs)

if test_query_multivec_encoder:
with self.subTest('query_multivec_encoder'):
topics = dataset.get_topics()
enc_topics = model(topics)
self.assertEqual(len(enc_topics), len(topics))
self.assertTrue('query_toks' in enc_topics.columns)
self.assertTrue('query_embs_toks' in enc_topics.columns)
self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
self.assertEqual(enc_topics.query_toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
self.assertEqual(enc_topics.query_embs_toks.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs_toks))
with self.subTest('query_multivec_encoder empty'):
enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
self.assertEqual(len(enc_topics_empty), 0)
self.assertTrue('query_toks' in enc_topics_empty.columns)
self.assertTrue('query_embs_toks' in enc_topics_empty.columns)
if test_doc_multivec_encoder:
with self.subTest('doc_multi_encoder'):
enc_docs = model(pd.DataFrame(docs_df))
self.assertEqual(len(enc_docs), len(docs_df))
self.assertTrue('toks' in enc_docs.columns)
self.assertTrue('doc_embs_toks' in enc_docs.columns)
self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
self.assertEqual(enc_docs.toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
self.assertEqual(enc_docs.doc_embs_toks.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs_toks))
with self.subTest('doc_multi_encoder empty'):
enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
self.assertEqual(len(enc_docs_empty), 0)
self.assertTrue('toks' in enc_docs_empty.columns)
self.assertTrue('doc_embs_toks' in enc_docs_empty.columns)

def test_tct(self):
from pyterrier_dr import TctColBert
Expand All @@ -129,6 +169,16 @@ def test_query2query(self):
from pyterrier_dr import Query2Query
self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

def test_bgem3(self):
from pyterrier_dr import BGEM3
# create BGEM3 instance
bgem3 = BGEM3(max_length=1024)

self._base_test(bgem3.query_multi_encoder(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
self._base_test(bgem3.doc_multi_encoder(), test_query_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)
def setUp(self):
import pyterrier as pt
if not pt.started():
Expand Down

0 comments on commit 0ebcde5

Please sign in to comment.