From 9017e960a308fa48378f6d824ba2774a4a731516 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Sat, 14 Oct 2023 21:46:48 +0100 Subject: [PATCH] better np_vec_loader test cases --- tests/test_flexindex.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_flexindex.py b/tests/test_flexindex.py index 4c1496b..6adfdef 100644 --- a/tests/test_flexindex.py +++ b/tests/test_flexindex.py @@ -140,6 +140,8 @@ def test_np_vec_loader(self): res = vec_loader(pd.DataFrame({ 'docid': [5, 1, 100, 198], })) + self.assertTrue(all(c in res.columns) for c in ['docid', 'doc_vec']) + self.assertEqual(len(res), 4) self.assertTrue((res.iloc[0]['doc_vec'] == dataset[5]['doc_vec']).all()) self.assertTrue((res.iloc[1]['doc_vec'] == dataset[1]['doc_vec']).all()) self.assertTrue((res.iloc[2]['doc_vec'] == dataset[100]['doc_vec']).all()) @@ -147,7 +149,10 @@ def test_np_vec_loader(self): with self.subTest('docno'): res = vec_loader(pd.DataFrame({ 'docno': ['20', '0', '100', '198'], + 'query': 'ignored', })) + self.assertTrue(all(c in res.columns) for c in ['docno', 'doc_vec', 'query']) + self.assertEqual(len(res), 4) self.assertTrue((res.iloc[0]['doc_vec'] == dataset[20]['doc_vec']).all()) self.assertTrue((res.iloc[1]['doc_vec'] == dataset[0]['doc_vec']).all()) self.assertTrue((res.iloc[2]['doc_vec'] == dataset[100]['doc_vec']).all()) @@ -158,7 +163,6 @@ def test_np_vec_loader(self): # - ada_ladr # - gar # - np_scorer - # - scann_retriever # - torch_vecs # - torch_scorer