-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Maximal Marginal Relevance (#36)
* mmr implementation
* refactor a bit
* scores as float
* documentation updates
* refactor
* add test for mmr
* ruff
* updated documentation
- Loading branch information
1 parent
0c7a334
commit 422c98d
Showing
8 changed files
with
153 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pyterrier as pt | ||
import pyterrier_alpha as pta | ||
|
||
|
||
class MmrScorer(pt.Transformer):
    """An MMR (Maximal Marginal Relevance) scorer (i.e., re-ranker).

    The MMR scorer re-orders documents by balancing relevance (from the initial scores) and diversity (based on the
    similarity of the document vectors).

    .. cite.dblp:: conf/sigir/CarbonellG98
    """
    def __init__(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, drop_doc_vec: bool = True, verbose: bool = False):
        """
        Args:
            Lambda: The balance parameter between relevance and diversity (default: 0.5)
            norm_rel: Whether to normalize relevance scores to [0, 1] (default: False)
            norm_sim: Whether to normalize similarity scores to [0, 1] (default: False)
            drop_doc_vec: Whether to drop the 'doc_vec' column after re-ranking (default: True)
            verbose: Whether to display verbose output (e.g., progress bars) (default: False)
        """
        self.Lambda = Lambda
        self.norm_rel = norm_rel
        self.norm_sim = norm_sim
        self.drop_doc_vec = drop_doc_vec
        self.verbose = verbose

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Re-rank the results of every query in ``inp`` using MMR.

        Args:
            inp: A result frame that also provides a 'doc_vec' column. The 'score' column is read as the
                relevance signal.

        Returns:
            A frame with the rows of each query re-ordered. 'rank' is the 0-based position in the new order
            and 'score' is the negated rank (as float), so sorting by descending score reproduces the order.
            The 'doc_vec' column is removed when ``drop_doc_vec`` is set.
        """
        pta.validate.result_frame(inp, extra_columns=['doc_vec'])

        if len(inp) == 0:
            # No groups would be produced below and pd.concat([]) raises ValueError,
            # so build the (empty) output frame directly.
            return self._build_frame(inp, [])

        out = []
        it = inp.groupby('qid')
        if self.verbose:
            it = pt.tqdm(it, unit='q', desc=repr(self))

        for _qid, frame in it:
            # Always work in float: an integer score column must not produce integer state.
            scores = np.asarray(frame['score'].values, dtype=float)
            dvec_matrix = np.stack(frame['doc_vec'])
            out.append(self._build_frame(frame, self._mmr_order(scores, dvec_matrix)))

        return pd.concat(out, ignore_index=True)

    def _mmr_order(self, scores: np.ndarray, dvec_matrix: np.ndarray) -> list:
        """Return the row positions of one query's documents in MMR selection order."""
        n_docs = scores.shape[0]

        # Cosine similarity between all document pairs. Zero vectors are left as zeros
        # (similarity 0 to everything) rather than turning into NaN through 0/0.
        norms = np.linalg.norm(dvec_matrix, axis=1)
        norms[norms == 0] = 1.0
        dvec_matrix = dvec_matrix / norms[:, None]
        dvec_sims = dvec_matrix @ dvec_matrix.T

        if self.norm_rel:
            scores = self._min_max(scores)
        if self.norm_sim:
            dvec_sims = self._min_max(dvec_sims)

        # marg_rels[i] is the highest similarity between document i and any selected document so far.
        marg_rels = np.zeros(n_docs, dtype=float)
        remaining = np.ones(n_docs, dtype=bool)
        order = []
        for _ in range(n_docs):
            # Only score documents that are still unselected. Masking by index (instead of pushing the
            # selected documents to an infinite penalty) keeps the result a permutation even when
            # Lambda == 1 (0 * inf would be NaN, which argmax treats as the maximum) or when scores
            # contain NaN / -inf.
            candidates = np.flatnonzero(remaining)
            mmr_scores = (self.Lambda * scores[candidates]) - ((1 - self.Lambda) * marg_rels[candidates])
            idx = int(candidates[mmr_scores.argmax()])
            order.append(idx)
            remaining[idx] = False
            marg_rels = np.maximum(marg_rels, dvec_sims[idx])
        return order

    def _build_frame(self, frame: pd.DataFrame, order: list) -> pd.DataFrame:
        """Re-order ``frame`` by ``order`` and assign the rank-derived 'score' and 'rank' columns."""
        new_frame = frame.iloc[order].reset_index(drop=True).assign(
            score=-np.arange(len(order), dtype=float),
            rank=np.arange(len(order)),
        )
        if self.drop_doc_vec:
            new_frame = new_frame.drop(columns='doc_vec')
        return new_frame

    @staticmethod
    def _min_max(values: np.ndarray) -> np.ndarray:
        """Min-max scale ``values`` to [0, 1]; a constant input maps to all zeros instead of NaN."""
        lowest = values.min()
        span = values.max() - lowest
        if span == 0:
            return np.zeros_like(values, dtype=float)
        return (values - lowest) / span

    __repr__ = pta.transformer_repr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pyterrier as pt | ||
import pyterrier_dr | ||
from . import FlexIndex | ||
|
||
|
||
def _mmr(self, *, Lambda: float = 0.5, norm_rel: bool = False, norm_sim: bool = False, drop_doc_vec: bool = True, verbose: bool = False) -> pt.Transformer:
    """Build an MMR (Maximal Marginal Relevance) re-ranking pipeline backed by this index.

    The returned pipeline chains this index's vector loader with a :class:`MmrScorer` configured with the
    given options. See :class:`MmrScorer` for more details on MMR.

    Args:
        Lambda: The balance parameter between relevance and diversity (default: 0.5)
        norm_rel: Whether to normalize relevance scores to [0, 1] (default: False)
        norm_sim: Whether to normalize similarity scores to [0, 1] (default: False)
        drop_doc_vec: Whether to drop the 'doc_vec' column after re-ranking (default: True)
        verbose: Whether to display verbose output (e.g., progress bars) (default: False)

    .. cite.dblp:: conf/sigir/CarbonellG98
    """
    loader = self.vec_loader()
    scorer = pyterrier_dr.MmrScorer(
        Lambda=Lambda,
        norm_rel=norm_rel,
        norm_sim=norm_sim,
        drop_doc_vec=drop_doc_vec,
        verbose=verbose,
    )
    return loader >> scorer


# Expose the factory as the ``mmr`` method of FlexIndex.
FlexIndex.mmr = _mmr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import unittest | ||
import numpy as np | ||
import pandas as pd | ||
from pyterrier_dr import MmrScorer | ||
|
||
|
||
class TestMmr(unittest.TestCase):
    """Unit tests for :class:`MmrScorer`."""

    def test_mmr(self):
        """Default-configured MMR re-ranks three queries of different sizes as expected."""
        mmr = MmrScorer()
        input_rows = [
            ['q0', 'd0', 1.0, np.array([0, 1, 0])],
            ['q0', 'd1', 0.5, np.array([0, 1, 1])],
            ['q0', 'd2', 0.5, np.array([1, 1, 1])],
            ['q0', 'd3', 0.1, np.array([1, 1, 0])],
            ['q1', 'd0', 0.6, np.array([0, 1, 0])],
            ['q2', 'd0', 0.4, np.array([0, 1, 0])],
            ['q2', 'd1', 0.3, np.array([0, 1, 1])],
        ]
        results = mmr(pd.DataFrame(input_rows, columns=['qid', 'docno', 'score', 'doc_vec']))

        # For q0, d1 and d2 tie on relevance, but d2 is less similar to the already
        # selected d0, so it is promoted ahead of d1.
        expected_rows = [
            ['q0', 'd0', 0.0, 0],
            ['q0', 'd2', -1.0, 1],
            ['q0', 'd1', -2.0, 2],
            ['q0', 'd3', -3.0, 3],
            ['q1', 'd0', 0.0, 0],
            ['q2', 'd0', 0.0, 0],
            ['q2', 'd1', -1.0, 1],
        ]
        expected = pd.DataFrame(expected_rows, columns=['qid', 'docno', 'score', 'rank'])
        pd.testing.assert_frame_equal(results, expected)
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()