diff --git a/ir_measures/__init__.py b/ir_measures/__init__.py
index 20e0b0c..e27039a 100644
--- a/ir_measures/__init__.py
+++ b/ir_measures/__init__.py
@@ -34,6 +34,10 @@
 msmarco = providers.registry['msmarco']
 pyndeval = providers.registry['pyndeval']
 ranx = providers.registry['ranx']
+runtime = providers.registry['runtime']
+
+define = providers.define
+define_byquery = providers.define_byquery
 
 CwlMetric = providers.CwlMetric
 
@@ -48,6 +52,7 @@
     gdeval, # doesn't work when installed from package #9
     accuracy,
     ranx,
+    runtime,
 ])
 evaluator = DefaultPipeline.evaluator
 calc_ctxt = DefaultPipeline.calc_ctxt # deprecated; replaced with evaluator
diff --git a/ir_measures/providers/__init__.py b/ir_measures/providers/__init__.py
index 0a36e28..05fdf17 100644
--- a/ir_measures/providers/__init__.py
+++ b/ir_measures/providers/__init__.py
@@ -15,3 +15,4 @@ def register(provider):
 from .trectools_provider import TrectoolsProvider
 from .msmarco_provider import MsMarcoProvider
 from .ranx_provider import RanxProvider
+from .runtime_provider import RuntimeProvider, define, define_byquery
diff --git a/ir_measures/providers/runtime_provider.py b/ir_measures/providers/runtime_provider.py
new file mode 100644
index 0000000..c3e5ade
--- /dev/null
+++ b/ir_measures/providers/runtime_provider.py
@@ -0,0 +1,71 @@
+import ir_measures
+from ir_measures import providers, measures, Metric
+
+
+class RuntimeProvider(providers.Provider):
+    """
+    Supports measures that are defined at runtime via `ir_measures.define()` and
+    `ir_measures.define_byquery()`.
+    """
+    NAME = 'runtime'
+
+    def supports(self, measure):
+        measure.validate_params()
+        if hasattr(measure, 'runtime_impl'):
+            return True
+        return False
+
+    def _evaluator(self, measures, qrels):
+        measures = ir_measures.util.flatten_measures(measures)
+        # Convert qrels to a pandas DataFrame, sorted by query_id/doc_id
+        qrels = ir_measures.util.QrelsConverter(qrels).as_pd_dataframe()
+        qrels.sort_values(by=['query_id', 'doc_id'], inplace=True)
+        return RuntimeEvaluator(measures, qrels)
+
+
+class RuntimeEvaluator(providers.Evaluator):
+    def __init__(self, measures, qrels):
+        super().__init__(measures, set(qrels['query_id'].unique()))
+        self.qrels = qrels
+
+    def _iter_calc(self, run):
+        run = ir_measures.util.RunConverter(run).as_pd_dataframe()
+        run.sort_values(by=['query_id', 'score'], ascending=[True, False], inplace=True)
+        for measure in self.measures:
+            yield from measure.runtime_impl(self.qrels, run)
+
+
+def define(impl, name=None, support_cutoff=True):
+    _SUPPORTED_PARAMS = {}
+    if support_cutoff:
+        _SUPPORTED_PARAMS['cutoff'] = measures.ParamInfo(dtype=int, required=False, desc='ranking cutoff threshold')
+    class _RuntimeMeasure(measures.Measure):
+        nonlocal _SUPPORTED_PARAMS
+        SUPPORTED_PARAMS = _SUPPORTED_PARAMS
+        NAME = name
+        __name__ = name
+
+        def runtime_impl(self, qrels, run):
+            if 'cutoff' in self.params and self.params['cutoff'] is not None:
+                cutoff = self.params['cutoff']
+                # assumes results already sorted (as is done in RuntimeEvaluator)
+                run = run.groupby('query_id').head(cutoff).reset_index(drop=True)
+            for qid, score in impl(qrels, run):
+                yield Metric(qid, self, score)
+    return _RuntimeMeasure()
+
+
+def _byquery_impl(impl):
+    def _wrapped(qrels, run):
+        for qid, run_subdf in run.groupby("query_id"):
+            qrels_subdf = qrels[qrels['query_id'] == qid]
+            res = impl(qrels_subdf, run_subdf)
+            yield qid, res
+    return _wrapped
+
+
+def define_byquery(impl, name=None, support_cutoff=True):
+    return define(_byquery_impl(impl), name or repr(impl), support_cutoff)
+
+
+providers.register(RuntimeProvider())
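Usage sketch (not part of the diff): how a measure created with the new define_byquery() API might be used once this provider is registered. The measure function and names below (judged_precision, JudgedPrec) are illustrative only, mirroring the pattern in test/test_runtime.py; define_byquery, the Measure@cutoff syntax, and calc_aggregate/iter_calc come from the existing ir_measures API.

import ir_measures

# Hypothetical per-query measure: fraction of retrieved documents judged relevant.
# With define_byquery(), `qrels` and `run` arrive as pandas DataFrames already
# restricted to a single query, and the function returns one scalar for that query.
def judged_precision(qrels, run):
    run = run.merge(qrels, 'left', on=['doc_id'])
    return (run['relevance'] > 0).sum() / len(run)

JudgedPrec = ir_measures.define_byquery(judged_precision, name='JudgedPrec')

# The resulting object behaves like any built-in measure, including cutoffs:
#   ir_measures.calc_aggregate([JudgedPrec@10], qrels, run)
#   (JudgedPrec@10).iter_calc(qrels, run)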
diff --git a/test/test_runtime.py b/test/test_runtime.py
new file mode 100644
index 0000000..769f495
--- /dev/null
+++ b/test/test_runtime.py
@@ -0,0 +1,129 @@
+import unittest
+import itertools
+import ir_measures
+
+
+class TestRuntime(unittest.TestCase):
+
+    def test_define_byquery(self):
+        def my_p(qrels, run):
+            run = run.merge(qrels, 'left', on=['doc_id'])
+            return (run['relevance'] > 0).sum() / len(run)
+        def my_s(qrels, run):
+            run = run.merge(qrels, 'left', on=['doc_id'])
+            return 1. if (run['relevance'] > 0).sum() else 0.
+        MyP = ir_measures.define_byquery(my_p)
+        MyS = ir_measures.define_byquery(my_s)
+        qrels = list(ir_measures.read_trec_qrels('''
+0 0 D0 0
+0 0 D1 1
+0 0 D2 1
+0 0 D3 2
+0 0 D4 0
+1 0 D0 1
+1 0 D3 0
+1 0 D5 2
+'''))
+        run = list(ir_measures.read_trec_run('''
+0 0 D0 1 0.8 run
+0 0 D2 2 0.7 run
+0 0 D1 3 0.3 run
+0 0 D3 4 0.4 run
+0 0 D4 5 0.1 run
+1 0 D1 1 0.8 run
+1 0 D3 2 0.7 run
+1 0 D4 3 0.3 run
+1 0 D2 4 0.4 run
+'''))
+
+        result = list((MyP@1).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0)
+
+        result = list((MyP@2).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.5)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25)
+
+        result = list((MyP@3).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.6666666666666666)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333)
+
+        result = list((MyS@2).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 1.)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5)
+
+    def test_define(self):
+        def my_p(qrels, run):
+            run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
+            for qid, df in run.groupby('query_id'):
+                yield qid, (df['relevance'] > 0).sum() / len(df)
+        def my_s(qrels, run):
+            run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
+            for qid, df in run.groupby('query_id'):
+                yield qid, 1. if (df['relevance'] > 0).sum() else 0.
+        MyP = ir_measures.define(my_p)
+        MyS = ir_measures.define(my_s)
+        qrels = list(ir_measures.read_trec_qrels('''
+0 0 D0 0
+0 0 D1 1
+0 0 D2 1
+0 0 D3 2
+0 0 D4 0
+1 0 D0 1
+1 0 D3 0
+1 0 D5 2
+'''))
+        run = list(ir_measures.read_trec_run('''
+0 0 D0 1 0.8 run
+0 0 D2 2 0.7 run
+0 0 D1 3 0.3 run
+0 0 D3 4 0.4 run
+0 0 D4 5 0.1 run
+1 0 D1 1 0.8 run
+1 0 D3 2 0.7 run
+1 0 D4 3 0.3 run
+1 0 D2 4 0.4 run
+'''))
+        result = list((MyP@1).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0)
+
+        result = list((MyP@2).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.5)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25)
+
+        result = list((MyP@3).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 0.6666666666666666)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333)
+
+        result = list((MyS@2).iter_calc(qrels, run))
+        self.assertEqual(result[0].query_id, "0")
+        self.assertEqual(result[0].value, 1.)
+        self.assertEqual(result[1].query_id, "1")
+        self.assertEqual(result[1].value, 0.)
+        self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5)
+
+
+if __name__ == '__main__':
+    unittest.main()
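For completeness, a sketch of the define() variant exercised by test_define above. This is illustrative only; the function body mirrors the test, and the cutoff note reflects the groupby('query_id').head(cutoff) truncation performed in runtime_impl.

import ir_measures

# Hypothetical measure using define(): the function receives the full qrels and run
# DataFrames (the run is sorted by descending score within each query) and must
# yield (query_id, value) pairs itself.
def judged_precision_all(qrels, run):
    run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
    for qid, df in run.groupby('query_id'):
        yield qid, (df['relevance'] > 0).sum() / len(df)

JudgedPrecAll = ir_measures.define(judged_precision_all, name='JudgedPrecAll')

# When a cutoff is given (e.g. JudgedPrecAll@10), runtime_impl truncates each
# query's ranking to the top-k rows before calling the function:
#   ir_measures.calc_aggregate([JudgedPrecAll@10], qrels, run)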