-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ils implementation * rename * refactor * fix typo
- Loading branch information
1 parent
a8be670
commit 0c7a334
Showing
6 changed files
with
222 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import Optional, Iterable, Tuple | ||
import numpy as np | ||
import pandas as pd | ||
import ir_measures | ||
import pyterrier as pt | ||
from pyterrier_dr import FlexIndex | ||
|
||
|
||
def ILS(index: FlexIndex, *, name: Optional[str] = None, verbose: bool = False) -> ir_measures.Measure: # noqa: N802
    """Build an ILS (Intra-List Similarity) measure backed by the vectors stored in ``index``.

    A higher ILS means the retrieved documents are more similar to one another, i.e., the
    result list is less diverse.

    This measure supports the ``@k`` convention for applying a top-k cutoff before scoring.

    Args:
        index (FlexIndex): The index to use for loading document vectors.
        name (str, optional): The name of the measure (default: "ILS").
        verbose (bool, optional): Whether to display a progress bar.

    Returns:
        ir_measures.Measure: An ILS measure object.

    .. cite.dblp:: conf/www/ZieglerMKL05
    """
    measure_name = name or 'ILS'

    def _score(qrels, results):
        # ILS only looks at the retrieved documents, so the qrels are ignored.
        return _ils(results, index, verbose=verbose)

    return ir_measures.define(_score, name=measure_name)
|
||
|
||
def ils(results: pd.DataFrame, index: Optional[FlexIndex] = None, *, verbose: bool = False) -> Iterable[Tuple[str, float]]:
    """Compute the ILS (Intra-List Similarity) of each query's result list.

    A higher ILS means the retrieved documents are more similar to one another, i.e., the
    result list is less diverse.

    Args:
        results: The result frame to calculate ILS for (PyTerrier ``qid``/``docno`` columns).
        index: The index to use for loading document vectors. Required if `results` does not
            have a `doc_vec` column.
        verbose: Whether to display a progress bar.

    Returns:
        Iterable[Tuple[str,float]]: An iterable of (qid, ILS) pairs.

    .. cite.dblp:: conf/www/ZieglerMKL05
    """
    # The shared implementation works on ir_measures-style column names.
    column_map = {'docno': 'doc_id', 'qid': 'query_id'}
    renamed = results.rename(columns=column_map)
    return _ils(renamed, index, verbose=verbose)
|
||
|
||
def _ils(results: pd.DataFrame, index: Optional[FlexIndex] = None, *, verbose: bool = False) -> Iterable[Tuple[str, float]]: | ||
res = {} | ||
|
||
if index is not None: | ||
results = index.vec_loader()(results.rename(columns={'doc_id': 'docno'})) | ||
|
||
if 'doc_vec' not in results: | ||
raise ValueError('You must provide index to ils() if results do not have a `doc_vec` column.') | ||
|
||
it = results.groupby('query_id') | ||
if verbose: | ||
it = pt.tqdm(it, unit='q', desc='ILS') | ||
|
||
for qid, frame in it: | ||
if len(frame) > 1: | ||
vec_matrix = np.stack(frame['doc_vec']) | ||
vec_matrix = vec_matrix / np.linalg.norm(vec_matrix, axis=1)[:, None] # normalize vectors | ||
vec_sims = vec_matrix @ vec_matrix.T | ||
upper_right = np.triu_indices(vec_sims.shape[0], k=1) | ||
res[qid] = np.mean(vec_sims[upper_right]) | ||
else: | ||
res[qid] = 0.0 # ILS is ill-defined when there's only one item. | ||
|
||
return res.items() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
Diversity | ||
======================================================= | ||
|
||
``pyterrier-dr`` provides a diversity evaluation measure, :func:`~pyterrier_dr.ILS` (Intra-List Similarity), | ||
which can be used to evaluate the diversity of search results based on the dense vectors of a :class:`~pyterrier_dr.FlexIndex`. | ||
|
||
This measure can be used alongside PyTerrier's built-in evaluation measures in a :func:`pyterrier.Experiment`. | ||
|
||
.. code-block:: python
    :caption: Compare the relevance and ILS of lexical and dense retrieval with a PyTerrier Experiment

    import pyterrier as pt
    from pyterrier.measures import nDCG, R
    from pyterrier_dr import FlexIndex, TasB
    from pyterrier_pisa import PisaIndex

    dataset = pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged')
    index = FlexIndex.from_hf('macavaney/msmarco-passage.tasb.flex')
    bm25 = PisaIndex.from_hf('macavaney/msmarco-passage.pisa').bm25()
    model = TasB.dot()

    pt.Experiment(
        [
            bm25,
            model >> index,
        ],
        dataset.get_topics(),
        dataset.get_qrels(),
        [nDCG@10, R(rel=2)@1000, index.ILS@10, index.ILS@1000]
    )
    # name   nDCG@10   R(rel=2)@1000  ILS@10    ILS@1000
    # BM25   0.498902  0.755495       0.852248  0.754691
    # TAS-B  0.716068  0.841756       0.889112  0.775415

.. autofunction:: pyterrier_dr.ILS

.. autofunction:: pyterrier_dr.ils
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ This functionality is covered in more detail in the following pages: | |
encoding | ||
indexing-retrieval | ||
prf | ||
diversity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import unittest | ||
import tempfile | ||
import numpy as np | ||
import pandas as pd | ||
from pyterrier_dr import ils, FlexIndex | ||
|
||
|
||
class TestIls(unittest.TestCase):
    """Tests for ``ils`` and the index-backed ILS measure on a four-document toy collection."""

    # (query, document) pairs used by every test: q0 retrieves four documents,
    # q1 a single document (scored 0.0) and q2 two documents.
    _PAIRS = [
        ['q0', 'd0'],
        ['q0', 'd1'],
        ['q0', 'd2'],
        ['q0', 'd3'],
        ['q1', 'd0'],
        ['q2', 'd0'],
        ['q2', 'd1'],
    ]

    @staticmethod
    def _docs():
        """Return freshly built toy documents with 3-dimensional vectors."""
        return [
            {'docno': 'd0', 'doc_vec': np.array([0, 1, 0])},
            {'docno': 'd1', 'doc_vec': np.array([0, 1, 1])},
            {'docno': 'd2', 'doc_vec': np.array([1, 1, 0])},
            {'docno': 'd3', 'doc_vec': np.array([1, 1, 1])},
        ]

    def _build_index(self, directory):
        """Create a FlexIndex under ``directory`` holding the toy documents."""
        index = FlexIndex(f'{directory}/index.flex')
        index.index(self._docs())
        return index

    def _ranking(self, columns):
        """Return the shared (query, document) pairs as a frame with the given column names."""
        return pd.DataFrame(self._PAIRS, columns=columns)

    def _assert_scores(self, scores):
        """Check the per-query ILS values of the full (uncut) rankings."""
        self.assertAlmostEqual(scores['q0'], 0.6874, places=3)
        self.assertAlmostEqual(scores['q1'], 0.0000, places=3)
        self.assertAlmostEqual(scores['q2'], 0.7071, places=3)

    def _assert_per_query(self, per_query, expected):
        """Check that ``per_query`` lists exactly the expected (query_id, value) pairs, in order."""
        self.assertEqual(len(expected), len(per_query))
        for metric, (qid, value) in zip(per_query, expected):
            self.assertEqual(metric.query_id, qid)
            self.assertAlmostEqual(metric.value, value, places=3)

    def test_ils_basic(self):
        # Vectors are supplied directly in the frame, so no index is needed.
        vec_by_docno = {doc['docno']: doc['doc_vec'] for doc in self._docs()}
        rows = [[qid, docno, vec_by_docno[docno]] for qid, docno in self._PAIRS]
        results = pd.DataFrame(rows, columns=['qid', 'docno', 'doc_vec'])
        self._assert_scores(dict(ils(results)))

    def test_ils_vec_from_index(self):
        with tempfile.TemporaryDirectory() as d:
            index = self._build_index(d)
            results = self._ranking(['qid', 'docno'])
            self._assert_scores(dict(ils(results, index)))

    def test_ils_measure_from_index(self):
        with tempfile.TemporaryDirectory() as d:
            index = self._build_index(d)
            results = self._ranking(['query_id', 'doc_id'])
            qrels = pd.DataFrame(columns=['query_id', 'doc_id', 'relevance'])  # qrels ignored
            result = index.ILS.calc(qrels, results)
            self.assertAlmostEqual(result.aggregated, 0.4648, places=3)
            self._assert_per_query(
                result.per_query,
                [('q0', 0.6874), ('q1', 0.0000), ('q2', 0.7071)],
            )

    def test_ils_measure_from_index_cutoff(self):
        with tempfile.TemporaryDirectory() as d:
            index = self._build_index(d)
            results = self._ranking(['query_id', 'doc_id'])
            qrels = pd.DataFrame(columns=['query_id', 'doc_id', 'relevance'])  # qrels ignored
            result = (index.ILS@2).calc(qrels, results)
            self.assertAlmostEqual(result.aggregated, 0.4714, places=3)
            # With a cutoff of 2, q0 is scored on d0 and d1 only.
            self._assert_per_query(
                result.per_query,
                [('q0', 0.7071), ('q1', 0.0000), ('q2', 0.7071)],
            )
|
||
|
||
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()