-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* runtime-defined measures (see #27) * reworking runtime-defined measures based on feedback from @cmacdonald - Support "cutoff" parameter (default on) - Name optional (defaults to repr of impl) - Runtime-defined measures don't get registered
- Loading branch information
1 parent
095d650
commit fe2012f
Showing
4 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import ir_measures | ||
from ir_measures import providers, measures, Metric | ||
|
||
|
||
class RuntimeProvider(providers.Provider): | ||
""" | ||
Supports measures that are defined at runtime via `ir_measures.define()` and | ||
`ir_measures.define_byquery()`. | ||
""" | ||
NAME = 'runtime' | ||
|
||
def supports(self, measure): | ||
measure.validate_params() | ||
if hasattr(measure, 'runtime_impl'): | ||
return True | ||
return False | ||
|
||
def _evaluator(self, measures, qrels): | ||
measures = ir_measures.util.flatten_measures(measures) | ||
# Convert qrels to dict_of_dict (input format used by pytrec_eval) | ||
qrels = ir_measures.util.QrelsConverter(qrels).as_pd_dataframe() | ||
qrels.sort_values(by=['query_id', 'doc_id'], inplace=True) | ||
return RuntimeEvaluator(measures, qrels) | ||
|
||
|
||
class RuntimeEvaluator(providers.Evaluator): | ||
def __init__(self, measures, qrels): | ||
super().__init__(measures, set(qrels['query_id'].unique())) | ||
self.qrels = qrels | ||
|
||
def _iter_calc(self, run): | ||
run = ir_measures.util.RunConverter(run).as_pd_dataframe() | ||
run.sort_values(by=['query_id', 'score'], ascending=[True, False], inplace=True) | ||
for measure in self.measures: | ||
yield from measure.runtime_impl(self.qrels, run) | ||
|
||
|
||
def define(impl, name=None, support_cutoff=True): | ||
_SUPPORTED_PARAMS = {} | ||
if support_cutoff: | ||
_SUPPORTED_PARAMS['cutoff'] = measures.ParamInfo(dtype=int, required=False, desc='ranking cutoff threshold') | ||
class _RuntimeMeasure(measures.Measure): | ||
nonlocal _SUPPORTED_PARAMS | ||
SUPPORTED_PARAMS = _SUPPORTED_PARAMS | ||
NAME = name | ||
__name__ = name | ||
|
||
def runtime_impl(self, qrels, run): | ||
if 'cutoff' in self.params and self.params['cutoff'] is not None: | ||
cutoff = self.params['cutoff'] | ||
# assumes results already sorted (as is done in RuntimeEvaluator) | ||
run = run.groupby('query_id').head(cutoff).reset_index(drop=True) | ||
for qid, score in impl(qrels, run): | ||
yield Metric(qid, self, score) | ||
return _RuntimeMeasure() | ||
|
||
|
||
def _byquery_impl(impl): | ||
def _wrapped(qrels, run): | ||
for qid, run_subdf in run.groupby("query_id"): | ||
qrels_subdf = qrels[qrels['query_id'] == qid] | ||
res = impl(qrels_subdf, run_subdf) | ||
yield qid, res | ||
return _wrapped | ||
|
||
|
||
def define_byquery(impl, name=None, support_cutoff=True): | ||
return define(_byquery_impl(impl), name or repr(impl), support_cutoff) | ||
|
||
|
||
providers.register(RuntimeProvider()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import unittest | ||
import itertools | ||
import ir_measures | ||
|
||
|
||
class TestRuntime(unittest.TestCase): | ||
|
||
def test_define_byquery(self): | ||
def my_p(qrels, run): | ||
run = run.merge(qrels, 'left', on=['doc_id']) | ||
return (run['relevance'] > 0).sum() / len(run) | ||
def my_s(qrels, run): | ||
run = run.merge(qrels, 'left', on=['doc_id']) | ||
return 1. if (run['relevance'] > 0).sum() else 0. | ||
MyP = ir_measures.define_byquery(my_p) | ||
MyS = ir_measures.define_byquery(my_s) | ||
qrels = list(ir_measures.read_trec_qrels(''' | ||
0 0 D0 0 | ||
0 0 D1 1 | ||
0 0 D2 1 | ||
0 0 D3 2 | ||
0 0 D4 0 | ||
1 0 D0 1 | ||
1 0 D3 0 | ||
1 0 D5 2 | ||
''')) | ||
run = list(ir_measures.read_trec_run(''' | ||
0 0 D0 1 0.8 run | ||
0 0 D2 2 0.7 run | ||
0 0 D1 3 0.3 run | ||
0 0 D3 4 0.4 run | ||
0 0 D4 5 0.1 run | ||
1 0 D1 1 0.8 run | ||
1 0 D3 2 0.7 run | ||
1 0 D4 3 0.3 run | ||
1 0 D2 4 0.4 run | ||
''')) | ||
|
||
result = list((MyP@1).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0) | ||
|
||
result = list((MyP@2).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.5) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25) | ||
|
||
result = list((MyP@3).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.6666666666666666) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333) | ||
|
||
result = list((MyS@2).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 1.) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5) | ||
|
||
def test_define(self): | ||
def my_p(qrels, run): | ||
run = run.merge(qrels, 'left', on=['query_id', 'doc_id']) | ||
for qid, df in run.groupby('query_id'): | ||
yield qid, (df['relevance'] > 0).sum() / len(df) | ||
def my_s(qrels, run): | ||
run = run.merge(qrels, 'left', on=['query_id', 'doc_id']) | ||
for qid, df in run.groupby('query_id'): | ||
yield qid, 1. if (df['relevance'] > 0).sum() else 0. | ||
MyP = ir_measures.define(my_p) | ||
MyS = ir_measures.define(my_s) | ||
qrels = list(ir_measures.read_trec_qrels(''' | ||
0 0 D0 0 | ||
0 0 D1 1 | ||
0 0 D2 1 | ||
0 0 D3 2 | ||
0 0 D4 0 | ||
1 0 D0 1 | ||
1 0 D3 0 | ||
1 0 D5 2 | ||
''')) | ||
run = list(ir_measures.read_trec_run(''' | ||
0 0 D0 1 0.8 run | ||
0 0 D2 2 0.7 run | ||
0 0 D1 3 0.3 run | ||
0 0 D3 4 0.4 run | ||
0 0 D4 5 0.1 run | ||
1 0 D1 1 0.8 run | ||
1 0 D3 2 0.7 run | ||
1 0 D4 3 0.3 run | ||
1 0 D2 4 0.4 run | ||
''')) | ||
result = list((MyP@1).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0) | ||
|
||
result = list((MyP@2).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.5) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25) | ||
|
||
result = list((MyP@3).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 0.6666666666666666) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333) | ||
|
||
result = list((MyS@2).iter_calc(qrels, run)) | ||
self.assertEqual(result[0].query_id, "0") | ||
self.assertEqual(result[0].value, 1.) | ||
self.assertEqual(result[1].query_id, "1") | ||
self.assertEqual(result[1].value, 0.) | ||
self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |