diff --git a/pyterrier_dr/_mmr.py b/pyterrier_dr/_mmr.py index 5931497..b9526cb 100644 --- a/pyterrier_dr/_mmr.py +++ b/pyterrier_dr/_mmr.py @@ -12,17 +12,19 @@ class MmrScorer(pt.Transformer): .. cite.dblp:: conf/sigir/CarbonellG98 """ - def __init__(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, verbose: bool = False): + def __init__(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, drop_doc_vec: bool = True, verbose: bool = False): """ Args: Lambda: The balance parameter between relevance and diversity (default: 0.5) norm_rel: Whether to normalize relevance scores to [0, 1] (default: False) norm_sim: Whether to normalize similarity scores to [0, 1] (default: False) + drop_doc_vec: Whether to drop the 'doc_vec' column after re-ranking (default: True) verbose: Whether to display verbose output (e.g., progress bars) (default: False) """ self.Lambda = Lambda self.norm_rel = norm_rel self.norm_sim = norm_sim + self.drop_doc_vec = drop_doc_vec self.verbose = verbose def transform(self, inp: pd.DataFrame) -> pd.DataFrame: @@ -51,10 +53,13 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: if marg_rels.shape[0] > 1: marg_rels = np.max(np.stack([marg_rels, dvec_sims[idx]]), axis=0) marg_rels[idx] = float('inf') # ignore this document from now on - out.append(frame.iloc[new_idxs].reset_index(drop=True).assign( + new_frame = frame.iloc[new_idxs].reset_index(drop=True).assign( score=-np.arange(len(new_idxs), dtype=float), rank=np.arange(len(new_idxs)) - )) + ) + if self.drop_doc_vec: + new_frame = new_frame.drop(columns='doc_vec') + out.append(new_frame) return pd.concat(out, ignore_index=True) diff --git a/pyterrier_dr/flex/diversity.py b/pyterrier_dr/flex/diversity.py index d61bec1..66cc287 100644 --- a/pyterrier_dr/flex/diversity.py +++ b/pyterrier_dr/flex/diversity.py @@ -3,7 +3,7 @@ from . import FlexIndex -def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, verbose: bool = False) -> pt.Transformer: +def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, drop_doc_vec: bool = True, verbose: bool = False) -> pt.Transformer: """Returns an MMR (Maximal Marginal Relevance) scorer (i.e., re-ranker) over this index. The method first loads vectors from the index and then applies :class:`MmrScorer` to re-rank the results. See @@ -13,9 +13,10 @@ def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = Lambda: The balance parameter between relevance and diversity (default: 0.5) norm_rel: Whether to normalize relevance scores to [0, 1] (default: False) norm_sim: Whether to normalize similarity scores to [0, 1] (default: False) + drop_doc_vec: Whether to drop the 'doc_vec' column after re-ranking (default: True) verbose: Whether to display verbose output (e.g., progress bars) (default: False) .. cite.dblp:: conf/sigir/CarbonellG98 """ - return self.vec_loader() >> pyterrier_dr.MmrScorer(Lambda=Lambda, norm_rel=norm_rel, norm_sim=norm_sim, verbose=verbose) + return self.vec_loader() >> pyterrier_dr.MmrScorer(Lambda=Lambda, norm_rel=norm_rel, norm_sim=norm_sim, drop_doc_vec=drop_doc_vec, verbose=verbose) FlexIndex.mmr = _mmr diff --git a/tests/test_mmr.py b/tests/test_mmr.py new file mode 100644 index 0000000..2d3cd9b --- /dev/null +++ b/tests/test_mmr.py @@ -0,0 +1,31 @@ +import unittest +import numpy as np +import pandas as pd +from pyterrier_dr import MmrScorer + + +class TestMmr(unittest.TestCase): + def test_mmr(self): + mmr = MmrScorer() + results = mmr(pd.DataFrame([ + ['q0', 'd0', 1.0, np.array([0, 1, 0])], + ['q0', 'd1', 0.5, np.array([0, 1, 1])], + ['q0', 'd2', 0.5, np.array([1, 1, 1])], + ['q0', 'd3', 0.1, np.array([1, 1, 0])], + ['q1', 'd0', 0.6, np.array([0, 1, 0])], + ['q2', 'd0', 0.4, np.array([0, 1, 0])], + ['q2', 'd1', 0.3, np.array([0, 1, 1])], + ], columns=['qid', 'docno', 'score', 'doc_vec'])) + pd.testing.assert_frame_equal(results, pd.DataFrame([ + ['q0', 'd0', 0.0, 0], + ['q0', 'd2', -1.0, 1], + ['q0', 'd1', -2.0, 2], + ['q0', 'd3', -3.0, 3], + ['q1', 'd0', 0.0, 0], + ['q2', 'd0', 0.0, 0], + ['q2', 'd1', -1.0, 1], + ], columns=['qid', 'docno', 'score', 'rank'])) + + +if __name__ == '__main__': + unittest.main()