From 655f060eecd93688237ff6ccfd8703af7800ac62 Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Mon, 9 Dec 2024 11:42:17 +0000 Subject: [PATCH] add a unit test for compilation --- tests/test_compile.py | 55 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/test_compile.py diff --git a/tests/test_compile.py b/tests/test_compile.py new file mode 100644 index 0000000..e040dc0 --- /dev/null +++ b/tests/test_compile.py @@ -0,0 +1,55 @@ +import unittest +import tempfile +import unittest +import numpy as np +import pandas as pd +import pyterrier_dr +from pyterrier_dr import FlexIndex + +class TestFlexIndex(unittest.TestCase): + + def _generate_data(self, count=2000, dim=100): + def random_unit_vec(): + v = np.random.rand(dim).astype(np.float32) + return v / np.linalg.norm(v) + return [ + {'docno': str(i), 'doc_vec': random_unit_vec()} + for i in range(count) + ] + + def test_compilation_with_rank_and_averageprf(self): + self._test_compilation_with_rank_and_prf(pyterrier_dr.AveragePrf) + + def test_compilation_with_rank_and_vectorprf(self): + self._test_compilation_with_rank_and_prf(pyterrier_dr.VectorPrf) + + def _test_compilation_with_rank_and_prf(self, prf_clz): + + with tempfile.TemporaryDirectory() as destdir: + index = FlexIndex(destdir+'/index') + dataset = self._generate_data(count=2000) + index.index(dataset) + + retr = index.retriever() + queries = pd.DataFrame([ + {'qid': '0', 'query_vec': dataset[0]['doc_vec']}, + {'qid': '1', 'query_vec': dataset[1]['doc_vec']}, + ]) + + pipe1 = retr >> prf_clz(k=3) >> retr + pipe1_opt = pipe1.compile() + self.assertEqual(3, pipe1_opt[0].num_results) + self.assertEqual(1000, pipe1_opt[-1].num_results) + #NB: pipe1 wouldnt actually work for PRF, as doc_vecs are not present. however compilation is valid + + pipe2 = retr >> index.vec_loader() >> pyterrier_dr.AveragePrf(k=3) >> (retr % 2) + pipe2_opt = pipe2.compile() + self.assertEqual(3, pipe2_opt[0].num_results) + self.assertEqual(2, pipe2_opt[-1].num_results) + + res2 = pipe2(queries) + res2_opt = pipe2_opt(queries) + + pd.testing.assert_frame_equal(res2, res2_opt) + + \ No newline at end of file