Skip to content

Commit

Permalink
add test for mmr
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Dec 19, 2024
1 parent e5129bf commit 5585426
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
11 changes: 8 additions & 3 deletions pyterrier_dr/_mmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ class MmrScorer(pt.Transformer):
.. cite.dblp:: conf/sigir/CarbonellG98
"""
def __init__(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, verbose: bool = False):
def __init__(
    self,
    *,
    Lambda: float = 0.5,
    norm_rel: bool = False,
    norm_sim: bool = False,
    drop_doc_vec: bool = True,
    verbose: bool = False,
):
    """Configure the MMR re-ranker.

    Every option is keyword-only and is stored unchanged on the instance
    under an attribute of the same name.

    Args:
        Lambda: Balance parameter between relevance and diversity
            (default: 0.5).
        norm_rel: Normalize relevance scores to [0, 1] when True
            (default: False).
        norm_sim: Normalize similarity scores to [0, 1] when True
            (default: False).
        drop_doc_vec: Remove the 'doc_vec' column from the re-ranked
            output when True (default: True).
        verbose: Display verbose output such as progress bars when True
            (default: False).
    """
    # Scoring options.
    self.Lambda, self.norm_rel, self.norm_sim = Lambda, norm_rel, norm_sim
    # Output-shaping and diagnostic options.
    self.drop_doc_vec, self.verbose = drop_doc_vec, verbose

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -51,10 +53,13 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
if marg_rels.shape[0] > 1:
marg_rels = np.max(np.stack([marg_rels, dvec_sims[idx]]), axis=0)
marg_rels[idx] = float('inf') # ignore this document from now on
out.append(frame.iloc[new_idxs].reset_index(drop=True).assign(
new_frame = frame.iloc[new_idxs].reset_index(drop=True).assign(
score=-np.arange(len(new_idxs), dtype=float),
rank=np.arange(len(new_idxs))
))
)
if self.drop_doc_vec:
new_frame = new_frame.drop(columns='doc_vec')
out.append(new_frame)

return pd.concat(out, ignore_index=True)

Expand Down
5 changes: 3 additions & 2 deletions pyterrier_dr/flex/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from . import FlexIndex


def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, verbose: bool = False) -> pt.Transformer:
def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, drop_doc_vec: bool = True, verbose: bool = False) -> pt.Transformer:
"""Returns an MMR (Maximal Marginal Relevance) scorer (i.e., re-ranker) over this index.
The method first loads vectors from the index and then applies :class:`MmrScorer` to re-rank the results. See
Expand All @@ -13,9 +13,10 @@ def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool =
Lambda: The balance parameter between relevance and diversity (default: 0.5)
norm_rel: Whether to normalize relevance scores to [0, 1] (default: False)
norm_sim: Whether to normalize similarity scores to [0, 1] (default: False)
drop_doc_vec: Whether to drop the 'doc_vec' column after re-ranking (default: True)
verbose: Whether to display verbose output (e.g., progress bars) (default: False)
.. cite.dblp:: conf/sigir/CarbonellG98
"""
return self.vec_loader() >> pyterrier_dr.MmrScorer(Lambda=Lambda, norm_rel=norm_rel, norm_sim=norm_sim, verbose=verbose)
return self.vec_loader() >> pyterrier_dr.MmrScorer(Lambda=Lambda, norm_rel=norm_rel, norm_sim=norm_sim, drop_doc_vec=drop_doc_vec, verbose=verbose)
FlexIndex.mmr = _mmr
31 changes: 31 additions & 0 deletions tests/test_mmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
import numpy as np
import pandas as pd
from pyterrier_dr import MmrScorer


class TestMmr(unittest.TestCase):
    """Unit tests for :class:`MmrScorer` re-ranking."""

    @staticmethod
    def _sample_results() -> pd.DataFrame:
        """Build a fresh multi-query result frame with per-document vectors.

        The frame covers three queries: ``q0`` with four documents, ``q1``
        with a single document, and ``q2`` with two documents. A new frame is
        returned on every call so tests cannot affect one another.
        """
        return pd.DataFrame([
            ['q0', 'd0', 1.0, np.array([0, 1, 0])],
            ['q0', 'd1', 0.5, np.array([0, 1, 1])],
            ['q0', 'd2', 0.5, np.array([1, 1, 1])],
            ['q0', 'd3', 0.1, np.array([1, 1, 0])],
            ['q1', 'd0', 0.6, np.array([0, 1, 0])],
            ['q2', 'd0', 0.4, np.array([0, 1, 0])],
            ['q2', 'd1', 0.3, np.array([0, 1, 1])],
        ], columns=['qid', 'docno', 'score', 'doc_vec'])

    def test_mmr(self):
        """Default settings re-rank each query and drop ``doc_vec``."""
        mmr = MmrScorer()
        results = mmr(self._sample_results())
        # Scores are the negated rank, so each query restarts at 0.0 / rank 0.
        expected = pd.DataFrame([
            ['q0', 'd0', 0.0, 0],
            ['q0', 'd2', -1.0, 1],
            ['q0', 'd1', -2.0, 2],
            ['q0', 'd3', -3.0, 3],
            ['q1', 'd0', 0.0, 0],
            ['q2', 'd0', 0.0, 0],
            ['q2', 'd1', -1.0, 1],
        ], columns=['qid', 'docno', 'score', 'rank'])
        pd.testing.assert_frame_equal(results, expected)

    def test_mmr_keep_doc_vec(self):
        """``drop_doc_vec=False`` keeps the vectors without changing the ranking."""
        dropped = MmrScorer()(self._sample_results())
        kept = MmrScorer(drop_doc_vec=False)(self._sample_results())
        self.assertNotIn('doc_vec', dropped.columns)
        self.assertIn('doc_vec', kept.columns)
        # The flag only affects the output columns, so once the vectors are
        # removed both outputs must be identical.
        pd.testing.assert_frame_equal(kept.drop(columns='doc_vec'), dropped)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 5585426

Please sign in to comment.