From 5c2f237bce8f2baa0749f722a2b277767213d5d8 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Sun, 24 Nov 2024 11:26:03 +0000 Subject: [PATCH] reworked prf as transformers, added a simple test --- pyterrier_dr/__init__.py | 4 ++-- pyterrier_dr/prf.py | 50 ++++++++++++++++++++++++++++++---------- tests/test_prf.py | 21 +++++++++++++++++ 3 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 tests/test_prf.py diff --git a/pyterrier_dr/__init__.py b/pyterrier_dr/__init__.py index 0b0d985..320146f 100644 --- a/pyterrier_dr/__init__.py +++ b/pyterrier_dr/__init__.py @@ -10,9 +10,9 @@ from pyterrier_dr.electra import ElectraScorer from pyterrier_dr.bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder from pyterrier_dr.cde import CDE, CDECache -from pyterrier_dr.prf import average_prf, vector_prf +from pyterrier_dr.prf import AveragePrf, VectorPrf __all__ = ["FlexIndex", "DocnoFile", "NilIndex", "NumpyIndex", "RankedLists", "FaissFlat", "FaissHnsw", "MemIndex", "TorchIndex", "BiEncoder", "BiQueryEncoder", "BiDocEncoder", "BiScorer", "HgfBiEncoder", "TasB", "RetroMAE", "SBertBiEncoder", "Ance", "Query2Query", "GTR", "TctColBert", "ElectraScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache", - "SimFn", "infer_device", "average_prf", "vector_prf"] + "SimFn", "infer_device", "AveragePrf", "VectorPrf"] diff --git a/pyterrier_dr/prf.py b/pyterrier_dr/prf.py index a8ea7c5..05a30d3 100644 --- a/pyterrier_dr/prf.py +++ b/pyterrier_dr/prf.py @@ -4,7 +4,7 @@ import pyterrier_alpha as pta -def vector_prf(*, alpha : float = 1, beta : float = 0.2, k : int = 3): +class VectorPrf(pt.Transformer): """ Performs a Rocchio-esque PRF by linearly combining the query_vec column with the doc_vec column of the top k documents. @@ -23,19 +23,35 @@ def vector_prf(*, alpha : float = 1, beta : float = 0.2, k : int = 3): Reference: Hang Li, Ahmed Mourad, Shengyao Zhuang, Bevan Koopman, Guido Zuccon. 
[Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers: Successes and Pitfalls](https://arxiv.org/pdf/2108.11044.pdf) """ - def _vector_prf(inp): + def __init__(self, + *, + alpha: float = 1, + beta: float = 0.2, + k: int = 3 + ): + self.alpha = alpha + self.beta = beta + self.k = k + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + """Transforms the input DataFrame query-by-query.""" + return pt.apply.by_query(self.transform_by_query, add_ranks=False)(inp) + + def transform_by_query(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.result_frame(inp, extra_columns=['query', 'query_vec', 'doc_vec']) # get the docvectors for the top k docs - doc_vecs = np.stack([ row.doc_vec for row in inp.head(k).itertuples() ]) + doc_vecs = np.stack([ row.doc_vec for row in inp.head(self.k).itertuples() ]) # combine their average and add to the query - query_vec = alpha * inp.iloc[0]['query_vec'] + beta * np.mean(doc_vecs, axis=0) + query_vec = self.alpha * inp.iloc[0]['query_vec'] + self.beta * np.mean(doc_vecs, axis=0) # generate new query dataframe with 'qid', 'query', 'query_vec' - return pd.DataFrame([[inp.iloc[0]['qid'], inp.iloc[0]['query'], query_vec]], columns=['qid', 'query', 'query_vec']) + return pd.DataFrame([[inp['qid'].iloc[0], inp['query'].iloc[0], query_vec]], columns=['qid', 'query', 'query_vec']) + + def __repr__(self): + return f"VectorPrf(alpha={self.alpha}, beta={self.beta}, k={self.k})" - return pt.apply.by_query(_vector_prf, add_ranks=False) -def average_prf(*, k : int = 3): +class AveragePrf(pt.Transformer): """ Performs Average PRF (as described by Li et al.) by averaging the query_vec column with the doc_vec column of the top k documents. @@ -51,16 +67,27 @@ def average_prf(*, k : int = 3): prf_pipe = model >> index >> index.vec_loader() >> pyterier_dr.average_prf() >> index Reference: Hang Li, Ahmed Mourad, Shengyao Zhuang, Bevan Koopman, Guido Zuccon. 
[Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers: Successes and Pitfalls](https://arxiv.org/pdf/2108.11044.pdf) """ - def _average_prf(inp): + def __init__(self, + *, + k: int = 3 + ): + self.k = k + + def transform(self, inp: pd.DataFrame) -> pd.DataFrame: + """Transforms the input DataFrame query-by-query.""" + return pt.apply.by_query(self.transform_by_query, add_ranks=False)(inp) + + def transform_by_query(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.result_frame(inp, extra_columns=['query_vec', 'doc_vec']) # get the docvectors for the top k docs and the query_vec - all_vecs = np.stack([inp.iloc[0]['query_vec']] + [row.doc_vec for row in inp.head(k).itertuples()]) + all_vecs = np.stack([inp.iloc[0]['query_vec']] + [row.doc_vec for row in inp.head(self.k).itertuples()]) # combine their average and add to the query query_vec = np.mean(all_vecs, axis=0) # generate new query dataframe with 'qid', 'query', 'query_vec' - return pd.DataFrame([[inp.iloc[0]['qid'], inp.iloc[0]['query'], query_vec]], columns=['qid', 'query', 'query_vec']) + return pd.DataFrame([[inp['qid'].iloc[0], inp['query'].iloc[0], query_vec]], columns=['qid', 'query', 'query_vec']) - return pt.apply.by_query(_average_prf, add_ranks=False) + def __repr__(self): + return f"AveragePrf(k={self.k})" diff --git a/tests/test_prf.py b/tests/test_prf.py new file mode 100644 index 0000000..0f03454 --- /dev/null +++ b/tests/test_prf.py @@ -0,0 +1,21 @@ +import unittest +import numpy as np +import pandas as pd +from pyterrier_dr import AveragePrf + + +class TestModels(unittest.TestCase): + + def test_avg_prf(self): + prf = AveragePrf() + inp = pd.DataFrame([['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])]], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) + out = prf(inp) + self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) + self.assertEqual(len(out), 1) + self.assertEqual(out['qid'][0], 'q1') + self.assertEqual(out['query'][0], 
'query') + np.testing.assert_array_equal(out['query_vec'][0], np.array([2.5, 3.5, 4.5])) + + +if __name__ == '__main__': + unittest.main()