Skip to content

Commit

Permalink
better np_vec_loader test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Oct 14, 2023
1 parent ba78f88 commit 9017e96
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/test_flexindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,19 @@ def test_np_vec_loader(self):
res = vec_loader(pd.DataFrame({
'docid': [5, 1, 100, 198],
}))
self.assertTrue(all(c in res.columns) for c in ['docid', 'doc_vec'])
self.assertEqual(len(res), 4)
self.assertTrue((res.iloc[0]['doc_vec'] == dataset[5]['doc_vec']).all())
self.assertTrue((res.iloc[1]['doc_vec'] == dataset[1]['doc_vec']).all())
self.assertTrue((res.iloc[2]['doc_vec'] == dataset[100]['doc_vec']).all())
self.assertTrue((res.iloc[3]['doc_vec'] == dataset[198]['doc_vec']).all())
with self.subTest('docno'):
res = vec_loader(pd.DataFrame({
'docno': ['20', '0', '100', '198'],
'query': 'ignored',
}))
self.assertTrue(all(c in res.columns) for c in ['docno', 'doc_vec', 'query'])
self.assertEqual(len(res), 4)
self.assertTrue((res.iloc[0]['doc_vec'] == dataset[20]['doc_vec']).all())
self.assertTrue((res.iloc[1]['doc_vec'] == dataset[0]['doc_vec']).all())
self.assertTrue((res.iloc[2]['doc_vec'] == dataset[100]['doc_vec']).all())
Expand All @@ -158,7 +163,6 @@ def test_np_vec_loader(self):
# - ada_ladr
# - gar
# - np_scorer
# - scann_retriever
# - torch_vecs
# - torch_scorer

Expand Down

0 comments on commit 9017e96

Please sign in to comment.