Runtime measures (#37)
* runtime-defined measures (see #27)

* reworking runtime-defined measures based on feedback from @cmacdonald
  - Support "cutoff" parameter (default on)
  - Name optional (defaults to repr of impl)
  - Runtime-defined measures don't get registered

(A usage sketch follows the file summary below.)
seanmacavaney authored Mar 4, 2022
1 parent 095d650 commit fe2012f
Showing 4 changed files with 206 additions and 0 deletions.
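
A minimal usage sketch of the runtime-defined measure API introduced in this commit, following the pattern used in test/test_runtime.py below. The measure body (frac_relevant) and the tiny qrels/run inputs are illustrative placeholders, not part of the commit.

import ir_measures

# Illustrative per-query measure body: fraction of retrieved docs that are relevant.
# With define_byquery, qrels and run arrive as pandas DataFrames for a single query.
def frac_relevant(qrels, run):
    run = run.merge(qrels, 'left', on=['doc_id'])
    return (run['relevance'] > 0).sum() / len(run)

# The name defaults to repr(frac_relevant); the measure is not added to the registry.
FracRel = ir_measures.define_byquery(frac_relevant)

qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 1
0 0 D1 0
'''))
run = list(ir_measures.read_trec_run('''
0 0 D0 1 0.9 run
0 0 D1 2 0.4 run
'''))

# The "cutoff" parameter is supported by default, so FracRel@1 truncates each
# ranking to the top result before calling frac_relevant.
print((FracRel@1).calc_aggregate(qrels, run))  # 1.0
print((FracRel@2).calc_aggregate(qrels, run))  # 0.5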
5 changes: 5 additions & 0 deletions ir_measures/__init__.py
@@ -34,6 +34,10 @@
msmarco = providers.registry['msmarco']
pyndeval = providers.registry['pyndeval']
ranx = providers.registry['ranx']
runtime = providers.registry['runtime']

define = providers.define
define_byquery = providers.define_byquery

CwlMetric = providers.CwlMetric

@@ -48,6 +52,7 @@
gdeval, # doesn't work when installed from package #9
accuracy,
ranx,
runtime,
])
evaluator = DefaultPipeline.evaluator
calc_ctxt = DefaultPipeline.calc_ctxt # deprecated; replaced with evaluator
1 change: 1 addition & 0 deletions ir_measures/providers/__init__.py
@@ -15,3 +15,4 @@ def register(provider):
from .trectools_provider import TrectoolsProvider
from .msmarco_provider import MsMarcoProvider
from .ranx_provider import RanxProvider
from .runtime_provider import RuntimeProvider, define, define_byquery
71 changes: 71 additions & 0 deletions ir_measures/providers/runtime_provider.py
@@ -0,0 +1,71 @@
import ir_measures
from ir_measures import providers, measures, Metric


class RuntimeProvider(providers.Provider):
"""
Supports measures that are defined at runtime via `ir_measures.define()` and
`ir_measures.define_byquery()`.
"""
NAME = 'runtime'

def supports(self, measure):
measure.validate_params()
if hasattr(measure, 'runtime_impl'):
return True
return False

def _evaluator(self, measures, qrels):
measures = ir_measures.util.flatten_measures(measures)
        # Convert qrels to a pandas DataFrame (sorted below by query_id/doc_id)
qrels = ir_measures.util.QrelsConverter(qrels).as_pd_dataframe()
qrels.sort_values(by=['query_id', 'doc_id'], inplace=True)
return RuntimeEvaluator(measures, qrels)


class RuntimeEvaluator(providers.Evaluator):
def __init__(self, measures, qrels):
super().__init__(measures, set(qrels['query_id'].unique()))
self.qrels = qrels

def _iter_calc(self, run):
run = ir_measures.util.RunConverter(run).as_pd_dataframe()
run.sort_values(by=['query_id', 'score'], ascending=[True, False], inplace=True)
for measure in self.measures:
yield from measure.runtime_impl(self.qrels, run)


def define(impl, name=None, support_cutoff=True):
_SUPPORTED_PARAMS = {}
if support_cutoff:
_SUPPORTED_PARAMS['cutoff'] = measures.ParamInfo(dtype=int, required=False, desc='ranking cutoff threshold')
class _RuntimeMeasure(measures.Measure):
nonlocal _SUPPORTED_PARAMS
SUPPORTED_PARAMS = _SUPPORTED_PARAMS
NAME = name
__name__ = name

def runtime_impl(self, qrels, run):
if 'cutoff' in self.params and self.params['cutoff'] is not None:
cutoff = self.params['cutoff']
# assumes results already sorted (as is done in RuntimeEvaluator)
run = run.groupby('query_id').head(cutoff).reset_index(drop=True)
for qid, score in impl(qrels, run):
yield Metric(qid, self, score)
return _RuntimeMeasure()


def _byquery_impl(impl):
def _wrapped(qrels, run):
for qid, run_subdf in run.groupby("query_id"):
qrels_subdf = qrels[qrels['query_id'] == qid]
res = impl(qrels_subdf, run_subdf)
yield qid, res
return _wrapped


def define_byquery(impl, name=None, support_cutoff=True):
return define(_byquery_impl(impl), name or repr(impl), support_cutoff)


providers.register(RuntimeProvider())
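
A hedged note on the two entry points above: define() expects impl to take the full qrels/run DataFrames and itself yield (query_id, score) pairs, whereas define_byquery() wraps a function that returns a single score for one query's qrels/run (via _byquery_impl). A sketch with hypothetical measure bodies, mirroring the tests below:

import ir_measures

# define(): the impl iterates over queries itself and yields (query_id, score) pairs.
def num_rel_retrieved(qrels, run):
    run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
    for qid, df in run.groupby('query_id'):
        yield qid, float((df['relevance'] > 0).sum())

# define_byquery(): the impl returns one score for a single query's sub-DataFrames.
def num_rel_retrieved_one(qrels, run):
    run = run.merge(qrels, 'left', on=['doc_id'])
    return float((run['relevance'] > 0).sum())

NumRel = ir_measures.define(num_rel_retrieved, name='NumRel')
NumRelBQ = ir_measures.define_byquery(num_rel_retrieved_one, name='NumRelBQ')
# Both NumRel and NumRelBQ should produce the same per-query values.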
129 changes: 129 additions & 0 deletions test/test_runtime.py
@@ -0,0 +1,129 @@
import unittest
import itertools
import ir_measures


class TestRuntime(unittest.TestCase):

def test_define_byquery(self):
def my_p(qrels, run):
run = run.merge(qrels, 'left', on=['doc_id'])
return (run['relevance'] > 0).sum() / len(run)
def my_s(qrels, run):
run = run.merge(qrels, 'left', on=['doc_id'])
return 1. if (run['relevance'] > 0).sum() else 0.
MyP = ir_measures.define_byquery(my_p)
MyS = ir_measures.define_byquery(my_s)
qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 0
0 0 D1 1
0 0 D2 1
0 0 D3 2
0 0 D4 0
1 0 D0 1
1 0 D3 0
1 0 D5 2
'''))
run = list(ir_measures.read_trec_run('''
0 0 D0 1 0.8 run
0 0 D2 2 0.7 run
0 0 D1 3 0.3 run
0 0 D3 4 0.4 run
0 0 D4 5 0.1 run
1 0 D1 1 0.8 run
1 0 D3 2 0.7 run
1 0 D4 3 0.3 run
1 0 D2 4 0.4 run
'''))

result = list((MyP@1).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0)

result = list((MyP@2).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.5)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25)

result = list((MyP@3).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.6666666666666666)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333)

result = list((MyS@2).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 1.)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5)

def test_define(self):
def my_p(qrels, run):
run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
for qid, df in run.groupby('query_id'):
yield qid, (df['relevance'] > 0).sum() / len(df)
def my_s(qrels, run):
run = run.merge(qrels, 'left', on=['query_id', 'doc_id'])
for qid, df in run.groupby('query_id'):
yield qid, 1. if (df['relevance'] > 0).sum() else 0.
MyP = ir_measures.define(my_p)
MyS = ir_measures.define(my_s)
qrels = list(ir_measures.read_trec_qrels('''
0 0 D0 0
0 0 D1 1
0 0 D2 1
0 0 D3 2
0 0 D4 0
1 0 D0 1
1 0 D3 0
1 0 D5 2
'''))
run = list(ir_measures.read_trec_run('''
0 0 D0 1 0.8 run
0 0 D2 2 0.7 run
0 0 D1 3 0.3 run
0 0 D3 4 0.4 run
0 0 D4 5 0.1 run
1 0 D1 1 0.8 run
1 0 D3 2 0.7 run
1 0 D4 3 0.3 run
1 0 D2 4 0.4 run
'''))
result = list((MyP@1).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@1).calc_aggregate(qrels, run), 0.0)

result = list((MyP@2).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.5)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@2).calc_aggregate(qrels, run), 0.25)

result = list((MyP@3).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 0.6666666666666666)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyP@3).calc_aggregate(qrels, run), 0.3333333333333333)

result = list((MyS@2).iter_calc(qrels, run))
self.assertEqual(result[0].query_id, "0")
self.assertEqual(result[0].value, 1.)
self.assertEqual(result[1].query_id, "1")
self.assertEqual(result[1].value, 0.)
self.assertEqual((MyS@2).calc_aggregate(qrels, run), 0.5)


if __name__ == '__main__':
unittest.main()
