diff --git a/pyterrier_dr/biencoder.py b/pyterrier_dr/biencoder.py index 91f3a02..99632cf 100644 --- a/pyterrier_dr/biencoder.py +++ b/pyterrier_dr/biencoder.py @@ -143,6 +143,9 @@ def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None, t def encode(self, texts, batch_size=None) -> np.array: return self.bi_encoder_model.encode_docs(texts, batch_size=batch_size or self.batch_size) + + def fuse_rank_cutoff(self, k): + return pt.RankCutoff(k) >> self def transform(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.columns(inp, includes=[self.text_field]) diff --git a/pyterrier_dr/flex/faiss_retr.py b/pyterrier_dr/flex/faiss_retr.py index fab8c6b..f6494d3 100644 --- a/pyterrier_dr/flex/faiss_retr.py +++ b/pyterrier_dr/flex/faiss_retr.py @@ -25,6 +25,13 @@ def __init__(self, flex_index, faiss_index, n_probe=None, ef_search=None, search self.drop_query_vec = drop_query_vec self.num_results = num_results + def fuse_rank_cutoff(self, k): + return None # disable fusion for ANN + if k < self.num_results: + return FaissRetriever(self.flex_index, self.faiss_index, + n_probe=self.n_probe, ef_search=self.ef_search, search_bounded_queue=self.search_bounded_queue, + num_results=k, qbatch=self.qbatch, drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) diff --git a/pyterrier_dr/flex/flatnav_retr.py b/pyterrier_dr/flex/flatnav_retr.py index 6eaed31..0e89867 100644 --- a/pyterrier_dr/flex/flatnav_retr.py +++ b/pyterrier_dr/flex/flatnav_retr.py @@ -20,6 +20,15 @@ def __init__(self, flex_index, flatnav_index, *, threads=16, ef_search=100, num_ self.num_initializations = num_initializations self.verbose = verbose + def fuse_rank_cutoff(self, k): + return None # disable fusion for ANN + if k < self.num_results: + return FlatNavRetriever(self.flex_index, self.flatnav_index, + num_results=k, ef_search=self.ef_search, + qbatch = self.qbatch, 
num_initializations=self.num_initializations, + drop_query_vec=self.drop_query_vec, verbose=self.verbose) + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) diff --git a/pyterrier_dr/flex/gar.py b/pyterrier_dr/flex/gar.py index 2f3c818..3849794 100644 --- a/pyterrier_dr/flex/gar.py +++ b/pyterrier_dr/flex/gar.py @@ -15,6 +15,11 @@ def __init__(self, flex_index, graph, score_fn, batch_size=128, num_results=1000 self.num_results = num_results self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + if k < self.num_results: + return FlexGar(self.flex_index, self.graph, score_fn=self.score_fn, + num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.result_frame(inp, extra_columns=['query_vec', 'score']) diff --git a/pyterrier_dr/flex/ladr.py b/pyterrier_dr/flex/ladr.py index bc79e46..8484311 100644 --- a/pyterrier_dr/flex/ladr.py +++ b/pyterrier_dr/flex/ladr.py @@ -91,6 +91,11 @@ def __init__(self, flex_index, graph, dense_scorer, num_results=1000, depth=100, self.max_hops = max_hops self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + if k < self.num_results: + return LadrAdaptive(self.flex_index, self.graph, self.dense_scorer, + num_results=k, depth=self.depth, max_hops=self.max_hops, drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.result_frame(inp, extra_columns=['query_vec']) docnos, config = self.flex_index.payload(return_dvecs=False) diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py index 620b135..93f6290 100644 --- a/pyterrier_dr/flex/np_retr.py +++ b/pyterrier_dr/flex/np_retr.py @@ -21,6 +21,10 @@ def __init__(self, self.batch_size = batch_size or 4096 self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + if k < self.num_results: + return NumpyRetriever(self.flex_index, num_results=k, 
batch_size=self.batch_size, drop_query_vec=self.drop_query_vec) + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) @@ -67,7 +71,9 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: docids = self.flex_index._load_docids(inp) dvecs, config = self.flex_index.payload(return_docnos=False) return inp.assign(doc_vec=list(dvecs[docids])) - + + def fuse_rank_cutoff(self, k): + return pt.RankCutoff(k) >> self class NumpyScorer(pt.Transformer): def __init__(self, flex_index: FlexIndex, *, num_results: Optional[int] = None): diff --git a/pyterrier_dr/flex/scann_retr.py b/pyterrier_dr/flex/scann_retr.py index 025e3b2..231a349 100644 --- a/pyterrier_dr/flex/scann_retr.py +++ b/pyterrier_dr/flex/scann_retr.py @@ -20,6 +20,11 @@ def __init__(self, flex_index, scann_index, num_results=1000, leaves_to_search=N self.qbatch = qbatch self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + return None # disable fusion for ANN + if k < self.num_results: + return ScannRetriever(self.flex_index, self.scann_index, num_results=k, leaves_to_search=self.leaves_to_search, qbatch=self.qbatch, drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) diff --git a/pyterrier_dr/flex/torch_retr.py b/pyterrier_dr/flex/torch_retr.py index 589b151..674e7fb 100644 --- a/pyterrier_dr/flex/torch_retr.py +++ b/pyterrier_dr/flex/torch_retr.py @@ -51,6 +51,15 @@ def __init__(self, self.qbatch = qbatch self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + if k < self.num_results: + return TorchRetriever( + self.flex_index, + self.torch_vecs, + num_results=k, + qbatch=self.qbatch, + drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) diff --git 
a/pyterrier_dr/flex/voyager_retr.py b/pyterrier_dr/flex/voyager_retr.py index 785812f..970ef1d 100644 --- a/pyterrier_dr/flex/voyager_retr.py +++ b/pyterrier_dr/flex/voyager_retr.py @@ -18,6 +18,17 @@ def __init__(self, flex_index, voyager_index, query_ef=None, num_results=1000, q self.qbatch = qbatch self.drop_query_vec = drop_query_vec + def fuse_rank_cutoff(self, k): + return None # disable fusion for ANN + if k < self.num_results: + return VoyagerRetriever( + self.flex_index, + self.voyager_index, + query_ef=self.query_ef, + num_results=k, + qbatch=self.qbatch, + drop_query_vec=self.drop_query_vec) + def transform(self, inp): pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) diff --git a/pyterrier_dr/prf.py b/pyterrier_dr/prf.py index 00fa87e..6f72d47 100644 --- a/pyterrier_dr/prf.py +++ b/pyterrier_dr/prf.py @@ -34,6 +34,9 @@ def __init__(self, self.beta = beta self.k = k + def compile(self) -> pt.Transformer: + return pt.RankCutoff(self.k) >> self + @pta.transform.by_query(add_ranks=False) def transform(self, inp: pd.DataFrame) -> pd.DataFrame: """Performs Vector PRF on the input dataframe.""" @@ -76,6 +79,9 @@ def __init__(self, ): self.k = k + def compile(self) -> pt.Transformer: + return pt.RankCutoff(self.k) >> self + @pta.transform.by_query(add_ranks=False) def transform(self, inp: pd.DataFrame) -> pd.DataFrame: """Performs Average PRF on the input dataframe.""" diff --git a/tests/test_compile.py b/tests/test_compile.py new file mode 100644 index 0000000..e040dc0 --- /dev/null +++ b/tests/test_compile.py @@ -0,0 +1,55 @@ +import unittest +import tempfile +import unittest +import numpy as np +import pandas as pd +import pyterrier_dr +from pyterrier_dr import FlexIndex + +class TestFlexIndex(unittest.TestCase): + + def _generate_data(self, count=2000, dim=100): + def random_unit_vec(): + v = np.random.rand(dim).astype(np.float32) + return v / np.linalg.norm(v) + return [ + {'docno': str(i), 'doc_vec': 
random_unit_vec()} + for i in range(count) + ] + + def test_compilation_with_rank_and_averageprf(self): + self._test_compilation_with_rank_and_prf(pyterrier_dr.AveragePrf) + + def test_compilation_with_rank_and_vectorprf(self): + self._test_compilation_with_rank_and_prf(pyterrier_dr.VectorPrf) + + def _test_compilation_with_rank_and_prf(self, prf_clz): + + with tempfile.TemporaryDirectory() as destdir: + index = FlexIndex(destdir+'/index') + dataset = self._generate_data(count=2000) + index.index(dataset) + + retr = index.retriever() + queries = pd.DataFrame([ + {'qid': '0', 'query_vec': dataset[0]['doc_vec']}, + {'qid': '1', 'query_vec': dataset[1]['doc_vec']}, + ]) + + pipe1 = retr >> prf_clz(k=3) >> retr + pipe1_opt = pipe1.compile() + self.assertEqual(3, pipe1_opt[0].num_results) + self.assertEqual(1000, pipe1_opt[-1].num_results) + # NB: pipe1 wouldn't actually work for PRF, as doc_vecs are not present. However, compilation is valid + + pipe2 = retr >> index.vec_loader() >> pyterrier_dr.AveragePrf(k=3) >> (retr % 2) + pipe2_opt = pipe2.compile() + self.assertEqual(3, pipe2_opt[0].num_results) + self.assertEqual(2, pipe2_opt[-1].num_results) + + res2 = pipe2(queries) + res2_opt = pipe2_opt(queries) + + pd.testing.assert_frame_equal(res2, res2_opt) + + \ No newline at end of file