diff --git a/README.md b/README.md index 99ba7f6..7e43859 100644 --- a/README.md +++ b/README.md @@ -78,10 +78,10 @@ qrels = pd.DataFrame([ # any iterable of namedtuples (e.g., list, generator, etc) qrels = [ - ir_measures.GenericQrel("Q0", "D0", 0), - ir_measures.GenericQrel("Q0", "D1", 1), - ir_measures.GenericQrel("Q1", "D0", 0), - ir_measures.GenericQrel("Q1", "D3", 2), + ir_measures.Qrel("Q0", "D0", 0), + ir_measures.Qrel("Q0", "D1", 1), + ir_measures.Qrel("Q1", "D0", 0), + ir_measures.Qrel("Q1", "D3", 2), ] # TREC-formatted qrels file @@ -118,10 +118,10 @@ run = pd.DataFrame([ # any iterable of namedtuples (e.g., list, generator, etc) run = [ - ir_measures.GenericScoredDoc("Q0", "D0", 1.2), - ir_measures.GenericScoredDoc("Q0", "D1", 1.0), - ir_measures.GenericScoredDoc("Q1", "D0", 2.4), - ir_measures.GenericScoredDoc("Q1", "D3", 3.6), + ir_measures.ScoredDoc("Q0", "D0", 1.2), + ir_measures.ScoredDoc("Q0", "D1", 1.0), + ir_measures.ScoredDoc("Q1", "D0", 2.4), + ir_measures.ScoredDoc("Q1", "D3", 3.6), ] ``` diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..77e6fb3 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,35 @@ +API Reference +=========================================== + +Metric Calculation +------------------------------------------- + +.. autofunction:: ir_measures.iter_calc + +.. autofunction:: ir_measures.calc_aggregate + +.. autofunction:: ir_measures.evaluator + +.. autoclass:: ir_measures.providers.Provider + :members: + +.. autoclass:: ir_measures.providers.Evaluator + :members: + +Parsing +------------------------------------------- + +.. autofunction:: ir_measures.parse_measure + +.. autofunction:: ir_measures.parse_trec_measure + +.. autofunction:: ir_measures.read_trec_qrels + +.. autofunction:: ir_measures.read_trec_run + +Data Classes +------------------------------------------- + +.. autoclass:: ir_measures.Metric +.. autoclass:: ir_measures.Qrel +.. autoclass:: ir_measures.ScoredDoc diff --git a/docs/conf.py b/docs/conf.py index 6874a00..43f80be 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,8 +27,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ -] +extensions = ['sphinx.ext.autodoc'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 91d40d6..bfc5232 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -110,14 +110,14 @@ map to (integer) relevance scores:: } } -**namedtuple iterable**: Any iterable of named tuples. You can use ``ir_measures.GenericQrel``, +**namedtuple iterable**: Any iterable of named tuples. You can use ``ir_measures.Qrel``, or any other NamedTuple with the fields ``query_id``, ``doc_id``, and ``relevance``:: qrels = [ - ir_measures.GenericQrel("Q0", "D0", 0), - ir_measures.GenericQrel("Q0", "D1", 1), - ir_measures.GenericQrel("Q1", "D0", 0), - ir_measures.GenericQrel("Q1", "D3", 2), + ir_measures.Qrel("Q0", "D0", 0), + ir_measures.Qrel("Q0", "D1", 1), + ir_measures.Qrel("Q1", "D0", 0), + ir_measures.Qrel("Q1", "D3", 2), ] Note that if the results are an iterator (such as the result of a generator), ``ir_measures`` will consume @@ -182,14 +182,14 @@ map to (float) ranking scores:: } } -**namedtuple iterable**: Any iterable of named tuples. You can use ``ir_measures.GenericScoredDoc``, +**namedtuple iterable**: Any iterable of named tuples. 
You can use ``ir_measures.ScoredDoc``,
 or any other NamedTuple with the fields ``query_id``, ``doc_id``, and ``score``::
 
     run = [
-        ir_measures.GenericScoredDoc("Q0", "D0", 1.2),
-        ir_measures.GenericScoredDoc("Q0", "D1", 1.0),
-        ir_measures.GenericScoredDoc("Q1", "D0", 2.4),
-        ir_measures.GenericScoredDoc("Q1", "D3", 3.6),
+        ir_measures.ScoredDoc("Q0", "D0", 1.2),
+        ir_measures.ScoredDoc("Q0", "D1", 1.0),
+        ir_measures.ScoredDoc("Q1", "D0", 2.4),
+        ir_measures.ScoredDoc("Q1", "D3", 3.6),
     ]
 
 Note that if the results are an iterator (such as the result of a generator), ``ir_measures`` will consume
@@ -227,6 +227,76 @@ easily make a qrels dataframe that is compatible with ir-measures like so::
 Note that ``read_trec_run`` returns a generator. If you need to use the qrels multiple times, wrap it in
 the ``list`` constructor to read the all qrels into memory.
 
+Measure Objects
+---------------------------------------
+
+Measure objects specify the measure you want to calculate, along with any
+parameters they may have. There are several ways to create them. The
+easiest is to specify them directly in code:
+
+    >>> from ir_measures import * # imports all measure names
+    >>> AP
+    AP
+    >>> AP(rel=2)
+    AP(rel=2)
+    >>> nDCG@20
+    nDCG@20
+    >>> P(rel=2)@10
+    P(rel=2)@10
+
+Notice that measures can include parameters. For instance, ``AP(rel=2)`` is the
+average precision measure with a minimum relevance level of 2 (i.e., documents
+need to be scored at least 2 to count as relevant). Similarly, ``nDCG@20`` specifies
+a ranking cutoff threshold of 20. See the measure's documentation for full details
+of available parameters.
+
+If you need to get a measure object from a string (e.g., if specified by the user
+as a command line argument), use the ``ir_measures.parse_measure`` function:
+
+    >>> ir_measures.parse_measure('AP')
+    AP
+    >>> ir_measures.parse_measure('AP(rel=2)')
+    AP(rel=2)
+    >>> ir_measures.parse_measure('nDCG@20')
+    nDCG@20
+    >>> ir_measures.parse_measure('P(rel=2)@10')
+    P(rel=2)@10
+
+If you are familiar with the measure and family names from ``trec_eval``, you can
+map them to measure objects using ``ir_measures.parse_trec_measure()``:
+
+    >>> ir_measures.parse_trec_measure('map')
+    [AP]
+    >>> ir_measures.parse_trec_measure('P') # expands to multiple levels
+    [P@5, P@10, P@15, P@20, P@30, P@100, P@200, P@500, P@1000]
+    >>> ir_measures.parse_trec_measure('P_3,8') # or 'P.3,8'
+    [P@3, P@8]
+    >>> ir_measures.parse_trec_measure('ndcg')
+    [nDCG]
+    >>> ir_measures.parse_trec_measure('ndcg_cut_10')
+    [nDCG@10]
+    >>> ir_measures.parse_trec_measure('official')
+    [P@5, P@10, P@15, P@20, P@30, P@100, P@200, P@500, P@1000, Rprec, Bpref, IPrec@0.0, IPrec@0.1, IPrec@0.2, IPrec@0.3, IPrec@0.4, IPrec@0.5, IPrec@0.6, IPrec@0.7, IPrec@0.8, IPrec@0.9, IPrec@1.0, AP, NumQ, NumRel, NumRet(rel=1), NumRet, RR]
+
+Note that a single ``trec_eval`` measure name can map to multiple measures,
+so measures are returned as a list.
+
+Measures can be passed into methods like ``ir_measures.calc_aggregate``, ``ir_measures.iter_calc``,
+and ``ir_measures.evaluator``. You can also calculate values from the measure object itself:
+
+    >>> AP.calc_aggregate(qrels, run)
+    0.2842120439595336
+    >>> (nDCG@10).calc_aggregate(qrels, run) # parens needed when @cutoff is used
+    0.6250748053944134
+    >>> for metric in (P(rel=2)@10).iter_calc(qrels, run):
+    ...     print(metric)
+    Metric(query_id='1', measure=P(rel=2)@10, value=0.5)
+    Metric(query_id='2', measure=P(rel=2)@10, value=0.8)
+    ...
+ Metric(query_id='35', measure=P(rel=2)@10, value=0.9) + + + Scoring multiple runs --------------------------------------- @@ -244,7 +314,6 @@ An evaluator object has ``calc_aggregate(run)`` and ``calc_iter(run)`` methods. {nDCG@10: 0.5286, P@5: 0.6228, P(rel=2)@5: 0.4628, Judged@10: 0.8485} - .. [1] In the examples, ``P@5`` and ``nDCG@10`` are returned first, as they are both calculated in one invocation of ``pytrec_eval``. Then, results for ``P(rel=2)@5`` are returned (as a second invocation of ``pytrec_eval`` because it only supports one relevance level at a time). diff --git a/docs/index.rst b/docs/index.rst index 416e558..b97ae75 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,3 +52,4 @@ Table of Contents getting-started measures providers + api diff --git a/ir_measures/__init__.py b/ir_measures/__init__.py index af5415f..df82409 100644 --- a/ir_measures/__init__.py +++ b/ir_measures/__init__.py @@ -1,6 +1,14 @@ __version__ = "0.1.4" from . import util -from .util import parse_measure, convert_trec_name, read_trec_qrels, read_trec_run, GenericQrel, GenericScoredDoc +from .util import (parse_measure, parse_trec_measure, + read_trec_qrels, read_trec_run, + Qrel, ScoredDoc, Metric, + GenericQrel, # deprecated; replaced with Qrel + GenericScoredDoc, # deprecated; replaced with ScoredDoc + convert_trec_name, # deprecated; replaced with parse_trec_measure + parse_trec_qrels, # deprecated; replaced with read_trec_qrels + parse_trec_run, # deprecated; replaced with read_trec_run + ) from . import measures from .measures import * from . import providers @@ -20,7 +28,7 @@ gdeval, # doesn't work when installed from package #9 ]) evaluator = DefaultPipeline.evaluator -calc_ctxt = DefaultPipeline.calc_ctxt # deprecated +calc_ctxt = DefaultPipeline.calc_ctxt # deprecated; replaced with evaluator iter_calc = DefaultPipeline.iter_calc calc_aggregate = DefaultPipeline.calc_aggregate diff --git a/ir_measures/__main__.py b/ir_measures/__main__.py index e909862..ad7769e 100644 --- a/ir_measures/__main__.py +++ b/ir_measures/__main__.py @@ -2,7 +2,7 @@ import sys import argparse import ir_measures -from ir_measures.util import GenericScoredDoc, GenericQrel +from ir_measures.util import ScoredDoc, Qrel def main_cli(): @@ -15,10 +15,8 @@ def main_cli(): parser.add_argument('--no_summary', '-n', action='store_true') parser.add_argument('--provider', choices=ir_measures.providers.registry.keys()) args = parser.parse_args() - run = (l.split() for l in open(args.run)) - run = (GenericScoredDoc(cols[0], cols[2], float(cols[4])) for cols in run) - qrels = (l.split() for l in open(args.qrels)) - qrels = (GenericQrel(cols[0], cols[2], int(cols[3])) for cols in qrels) + run = ir_measures.read_trec_run(args.run) + qrels = ir_measures.read_trec_qrels(args.qrels) measures, errors = [], [] for mstr in args.measures: for m in mstr.split(): diff --git a/ir_measures/measures/__init__.py b/ir_measures/measures/__init__.py index a005f86..327f0dd 100644 --- a/ir_measures/measures/__init__.py +++ b/ir_measures/measures/__init__.py @@ -9,7 +9,7 @@ def register(measure, aliases=[], name=None): registry[alias] = measure return registry -from .base import BaseMeasure, ParamInfo, MultiMeasures, MeanAgg, SumAgg +from .base import Measure, ParamInfo, MultiMeasures, MeanAgg, SumAgg from .ap import AP, MAP, _AP from .bpref import Bpref, BPref, _Bpref from .err import ERR, _ERR diff --git a/ir_measures/measures/ap.py b/ir_measures/measures/ap.py index a2cd583..0450019 100644 --- a/ir_measures/measures/ap.py +++ 
b/ir_measures/measures/ap.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _AP(measures.BaseMeasure): +class _AP(measures.Measure): """ The [Mean] Average Precision ([M]AP). The average precision of a single query is the mean of the precision scores at each relevant item returned in a search results list. diff --git a/ir_measures/measures/base.py b/ir_measures/measures/base.py index 3c598a6..3a2b74e 100644 --- a/ir_measures/measures/base.py +++ b/ir_measures/measures/base.py @@ -2,7 +2,7 @@ import ir_measures -class BaseMeasure: +class Measure: NAME = None AT_PARAM = 'cutoff' # allows measures to configure which param measure@X updates (default is cutoff) SUPPORTED_PARAMS = {} @@ -69,7 +69,7 @@ def __repr__(self): return result def __eq__(self, other): - if isinstance(other, BaseMeasure): + if isinstance(other, Measure): return repr(self) == repr(other) return False diff --git a/ir_measures/measures/bpref.py b/ir_measures/measures/bpref.py index 7b58e8a..989567e 100644 --- a/ir_measures/measures/bpref.py +++ b/ir_measures/measures/bpref.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _Bpref(measures.BaseMeasure): +class _Bpref(measures.Measure): """ Binary Preference (Bpref). This measure examines the relative ranks of judged relevant and non-relevant documents. Non-judged documents are not considered. diff --git a/ir_measures/measures/err.py b/ir_measures/measures/err.py index 21c5850..c4158bc 100644 --- a/ir_measures/measures/err.py +++ b/ir_measures/measures/err.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _ERR(measures.BaseMeasure): +class _ERR(measures.Measure): """ The Expected Reciprocal Rank (ERR) is a precision-focused measure. In essence, an extension of reciprocal rank that encapsulates both graded relevance and diff --git a/ir_measures/measures/infap.py b/ir_measures/measures/infap.py index 1924297..687ac12 100644 --- a/ir_measures/measures/infap.py +++ b/ir_measures/measures/infap.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _infAP(measures.BaseMeasure): +class _infAP(measures.Measure): """ Inferred AP. AP implementation that accounts for pooled-but-unjudged documents by assuming that they are relevant at the same proportion as other judged documents. Essentially, skips diff --git a/ir_measures/measures/iprec.py b/ir_measures/measures/iprec.py index 369ae90..f3a9a00 100644 --- a/ir_measures/measures/iprec.py +++ b/ir_measures/measures/iprec.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _IPrec(measures.BaseMeasure): +class _IPrec(measures.Measure): """ Interpolated Precision at a given recall cutoff. Used for building precision-recall graphs. 
Unlike most measures, where @ indicates an absolute cutoff threshold, here @ sets the recall diff --git a/ir_measures/measures/judged.py b/ir_measures/measures/judged.py index 08df891..bbe65bb 100644 --- a/ir_measures/measures/judged.py +++ b/ir_measures/measures/judged.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _Judged(measures.BaseMeasure): +class _Judged(measures.Measure): """ Percentage of results in the top k (cutoff) results that have relevance judgments. Equivalent to P@k with a rel lower than any judgment. diff --git a/ir_measures/measures/ndcg.py b/ir_measures/measures/ndcg.py index d6bb3d9..214cdca 100644 --- a/ir_measures/measures/ndcg.py +++ b/ir_measures/measures/ndcg.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _nDCG(measures.BaseMeasure): +class _nDCG(measures.Measure): """ The normalized Discounted Cumulative Gain (nDCG). Uses graded labels - systems that put the highest graded documents at the top of the ranking. diff --git a/ir_measures/measures/numq.py b/ir_measures/measures/numq.py index 97a5046..abdbed8 100644 --- a/ir_measures/measures/numq.py +++ b/ir_measures/measures/numq.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo, SumAgg +from .base import Measure, ParamInfo, SumAgg -class _NumQ(measures.BaseMeasure): +class _NumQ(measures.Measure): """ The total number of queries. """ diff --git a/ir_measures/measures/numrel.py b/ir_measures/measures/numrel.py index 0547b2f..20c845d 100644 --- a/ir_measures/measures/numrel.py +++ b/ir_measures/measures/numrel.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo, SumAgg +from .base import Measure, ParamInfo, SumAgg -class _NumRel(measures.BaseMeasure): +class _NumRel(measures.Measure): """ The number of relevant documents the query has (independent of what the system retrieved). """ diff --git a/ir_measures/measures/numret.py b/ir_measures/measures/numret.py index 059b7f4..5dd6f86 100644 --- a/ir_measures/measures/numret.py +++ b/ir_measures/measures/numret.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo, SumAgg +from .base import Measure, ParamInfo, SumAgg -class _NumRet(measures.BaseMeasure): +class _NumRet(measures.Measure): """ The number of results returned. When rel is provided, counts the number of documents returned with at least that relevance score (inclusive). diff --git a/ir_measures/measures/p.py b/ir_measures/measures/p.py index 4d0f70a..6d10224 100644 --- a/ir_measures/measures/p.py +++ b/ir_measures/measures/p.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _P(measures.BaseMeasure): +class _P(measures.Measure): """ Basic measure for that computes the percentage of documents in the top cutoff results that are labeled as relevant. cutoff is a required parameter, and can be provided as diff --git a/ir_measures/measures/r.py b/ir_measures/measures/r.py index d29f902..19725df 100644 --- a/ir_measures/measures/r.py +++ b/ir_measures/measures/r.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _R(measures.BaseMeasure): +class _R(measures.Measure): """ Recall@k (R@k). 
The fraction of relevant documents for a query that have been retrieved by rank k. """ diff --git a/ir_measures/measures/rbp.py b/ir_measures/measures/rbp.py index 7f75244..1667240 100644 --- a/ir_measures/measures/rbp.py +++ b/ir_measures/measures/rbp.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _RBP(measures.BaseMeasure): +class _RBP(measures.Measure): """ The Rank-Biased Precision (RBP) TODO: write diff --git a/ir_measures/measures/rprec.py b/ir_measures/measures/rprec.py index 8802d82..0ed1d4c 100644 --- a/ir_measures/measures/rprec.py +++ b/ir_measures/measures/rprec.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _Rprec(measures.BaseMeasure): +class _Rprec(measures.Measure): """ The precision of at R, where R is the number of relevant documents for a given query. Has the cute property that it is also the recall at R. diff --git a/ir_measures/measures/rr.py b/ir_measures/measures/rr.py index f90abac..82bfb0c 100644 --- a/ir_measures/measures/rr.py +++ b/ir_measures/measures/rr.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _RR(measures.BaseMeasure): +class _RR(measures.Measure): """ The [Mean] Reciprocal Rank ([M]RR) is a precision-focused measure that scores based on the reciprocal of the rank of the highest-scoring relevance document. An optional cutoff can be provided to limit the diff --git a/ir_measures/measures/setp.py b/ir_measures/measures/setp.py index ac75426..80f9d1b 100644 --- a/ir_measures/measures/setp.py +++ b/ir_measures/measures/setp.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _SetP(measures.BaseMeasure): +class _SetP(measures.Measure): """ The Set Precision (SetP); i.e., the number of relevant docs divided by the total number retrieved """ diff --git a/ir_measures/measures/success.py b/ir_measures/measures/success.py index bd84c5a..45881b3 100644 --- a/ir_measures/measures/success.py +++ b/ir_measures/measures/success.py @@ -1,8 +1,8 @@ from ir_measures import measures -from .base import BaseMeasure, ParamInfo +from .base import Measure, ParamInfo -class _Success(measures.BaseMeasure): +class _Success(measures.Measure): """ 1 if a document with at least rel relevance is found in the first cutoff documents, else 0. """ diff --git a/ir_measures/providers/__init__.py b/ir_measures/providers/__init__.py index 074aa72..7b21c91 100644 --- a/ir_measures/providers/__init__.py +++ b/ir_measures/providers/__init__.py @@ -3,7 +3,7 @@ def register(provider): registry[provider.NAME] = provider return provider -from .base import MeasureProvider, BaseMeasureEvaluator +from .base import Provider, Evaluator from .fallback_provider import FallbackProvider from .pytrec_eval_provider import PytrecEvalProvider from .judged_provider import JudgedProvider diff --git a/ir_measures/providers/base.py b/ir_measures/providers/base.py index 91a982d..2b7a99d 100644 --- a/ir_measures/providers/base.py +++ b/ir_measures/providers/base.py @@ -1,17 +1,44 @@ import deprecation import contextlib -from collections import namedtuple +from typing import Iterator, Dict, Union from ir_measures import providers, measures -class MeasureProvider: +class Evaluator: + """ + The base class for scoring runs for a given set of measures and qrels. 
+ Returned from ``.evaluator(measures, qrels)`` calls. + """ + def __init__(self, measures): + self.measures = measures + + def iter_calc(self, run) -> Iterator['Metric']: + """ + Yields per-topic metrics this run. + """ + raise NotImplementedError() + + def calc_aggregate(self, run) -> Dict[measures.Measure, Union[float, int]]: + """ + Returns aggregated measure values for this run. + """ + aggregators = {m: m.aggregator() for m in self.measures} + for metric in self.iter_calc(run): + aggregators[metric.measure].add(metric.value) + return {m: agg.result() for m, agg in aggregators.items()} + + +class Provider: + """ + The base class for all measure providers (e.g., pytrec_eval, gdeval, etc.). + """ NAME = None SUPPORTED_MEASURES = [] def __init__(self): self._is_available = None - def evaluator(self, measures, qrels): + def evaluator(self, measures, qrels) -> Evaluator: if self.is_available(): return self._evaluator(measures, qrels) else: @@ -71,21 +98,6 @@ def initialize(self): pass -class BaseMeasureEvaluator: - def __init__(self, measures): - self.measures = measures - - def iter_calc(self, run): - raise NotImplementedError() - - def calc_aggregate(self, run): - aggregators = {m: m.aggregator() for m in self.measures} - for metric in self.iter_calc(run): - aggregators[metric.measure].add(metric.value) - return {m: agg.result() for m, agg in aggregators.items()} - - - class ParamSpec: def validate(self, value): raise NotImplementedError() @@ -119,5 +131,3 @@ def __repr__(self): return repr(self.choices) NOT_PROVIDED = measures.base._NOT_PROVIDED - -Metric = namedtuple('Metric', ['query_id', 'measure', 'value']) diff --git a/ir_measures/providers/fallback_provider.py b/ir_measures/providers/fallback_provider.py index 131a7fc..c6cfff0 100644 --- a/ir_measures/providers/fallback_provider.py +++ b/ir_measures/providers/fallback_provider.py @@ -3,7 +3,7 @@ from ir_measures.util import QrelsConverter, RunConverter, flatten_measures -class FallbackProvider(providers.MeasureProvider): +class FallbackProvider(providers.Provider): def __init__(self, providers): super().__init__() self.providers = providers @@ -39,7 +39,7 @@ def supports(self, measure): return any(p.is_available() and p.supports(measure) for p in self.providers) -class FallbackEvaluator(providers.BaseMeasureEvaluator): +class FallbackEvaluator(providers.Evaluator): def __init__(self, measures, evaluators): super().__init__(measures) self.evaluators = evaluators diff --git a/ir_measures/providers/gdeval_provider.py b/ir_measures/providers/gdeval_provider.py index 47cf331..4d0bd7d 100644 --- a/ir_measures/providers/gdeval_provider.py +++ b/ir_measures/providers/gdeval_provider.py @@ -4,11 +4,11 @@ import tempfile import contextlib import ir_measures -from ir_measures import providers, measures -from ir_measures.providers.base import Any, Choices, Metric, NOT_PROVIDED +from ir_measures import providers, measures, Metric +from ir_measures.providers.base import Any, Choices, NOT_PROVIDED -class GdevalProvider(providers.MeasureProvider): +class GdevalProvider(providers.Provider): """ gdeval """ @@ -41,7 +41,7 @@ def initialize(self): raise RuntimeError('perl not available', ex) -class GdevalEvaluator(providers.BaseMeasureEvaluator): +class GdevalEvaluator(providers.Evaluator): def __init__(self, measures, qrels, invocations): super().__init__(measures) self.qrels = qrels diff --git a/ir_measures/providers/judged_provider.py b/ir_measures/providers/judged_provider.py index 96a58ca..7a974bc 100644 --- 
a/ir_measures/providers/judged_provider.py +++ b/ir_measures/providers/judged_provider.py @@ -1,10 +1,10 @@ import contextlib import ir_measures -from ir_measures import providers, measures -from ir_measures.providers.base import Any, Choices, Metric, NOT_PROVIDED +from ir_measures import providers, measures, Metric +from ir_measures.providers.base import Any, Choices, NOT_PROVIDED -class JudgedProvider(providers.MeasureProvider): +class JudgedProvider(providers.Provider): """ python implementation of judgment rate """ @@ -23,7 +23,7 @@ def _evaluator(self, measures, qrels): return JudgedEvaluator(measures, qrels, cutoffs) -class JudgedEvaluator(providers.BaseMeasureEvaluator): +class JudgedEvaluator(providers.Evaluator): def __init__(self, measures, qrels, cutoffs): super().__init__(measures) self.qrels = ir_measures.util.QrelsConverter(qrels).as_dict_of_dict() diff --git a/ir_measures/providers/msmarco_provider.py b/ir_measures/providers/msmarco_provider.py index f41bc87..fdd035c 100644 --- a/ir_measures/providers/msmarco_provider.py +++ b/ir_measures/providers/msmarco_provider.py @@ -1,11 +1,11 @@ import contextlib import ir_measures -from ir_measures import providers, measures -from ir_measures.providers.base import Any, Choices, Metric, NOT_PROVIDED +from ir_measures import providers, measures, Metric +from ir_measures.providers.base import Any, Choices, NOT_PROVIDED from ir_measures.bin import msmarco_eval import sys -class MsMarcoProvider(providers.MeasureProvider): +class MsMarcoProvider(providers.Provider): """ MS MARCO's implementation of RR """ @@ -27,7 +27,7 @@ def _evaluator(self, measures, qrels): return MsMarcoEvaluator(measures, qrels, invocations) -class MsMarcoEvaluator(providers.BaseMeasureEvaluator): +class MsMarcoEvaluator(providers.Evaluator): def __init__(self, measures, qrels, invocations): super().__init__(measures) self.qrels_by_rel = {rel: {} for _, rel, _ in invocations} diff --git a/ir_measures/providers/pytrec_eval_provider.py b/ir_measures/providers/pytrec_eval_provider.py index c093477..2224110 100644 --- a/ir_measures/providers/pytrec_eval_provider.py +++ b/ir_measures/providers/pytrec_eval_provider.py @@ -1,10 +1,10 @@ import contextlib import ir_measures -from ir_measures import providers, measures -from ir_measures.providers.base import Any, Choices, Metric, NOT_PROVIDED +from ir_measures import providers, measures, Metric +from ir_measures.providers.base import Any, Choices, NOT_PROVIDED -class PytrecEvalProvider(providers.MeasureProvider): +class PytrecEvalProvider(providers.Provider): """ pytrec_eval @@ -142,7 +142,7 @@ def initialize(self): raise RuntimeError('pytrec_eval not available', ex) -class PytrecEvalEvaluator(providers.BaseMeasureEvaluator): +class PytrecEvalEvaluator(providers.Evaluator): def __init__(self, measures, invokers): super().__init__(measures) self.invokers = invokers diff --git a/ir_measures/providers/trectools_provider.py b/ir_measures/providers/trectools_provider.py index 64eb0a5..3ca48ee 100644 --- a/ir_measures/providers/trectools_provider.py +++ b/ir_measures/providers/trectools_provider.py @@ -2,11 +2,11 @@ import contextlib import functools import ir_measures -from ir_measures import providers, measures -from ir_measures.providers.base import Any, Choices, Metric, NOT_PROVIDED +from ir_measures import providers, measures, Metric +from ir_measures.providers.base import Any, Choices, NOT_PROVIDED -class TrectoolsProvider(providers.MeasureProvider): +class TrectoolsProvider(providers.Provider): """ trectools @@ -107,7 +107,7 
@@ def initialize(self): raise RuntimeError('trectools not available', ex) -class TrectoolsEvaluator(providers.BaseMeasureEvaluator): +class TrectoolsEvaluator(providers.Evaluator): def __init__(self, measures, qrels, invocations, trectools): super().__init__(measures) self.qrels = qrels diff --git a/ir_measures/util.py b/ir_measures/util.py index 0e9a721..9e484ee 100644 --- a/ir_measures/util.py +++ b/ir_measures/util.py @@ -7,11 +7,36 @@ import tempfile from typing import List from collections import namedtuple +from typing import NamedTuple, Union import ir_measures +class Qrel(NamedTuple): + query_id: str + doc_id: str + relevance: int -GenericQrel = namedtuple('GenericQrel', ['query_id', 'doc_id', 'relevance']) -GenericScoredDoc = namedtuple('GenericScoredDoc', ['query_id', 'doc_id', 'score']) +class ScoredDoc(NamedTuple): + query_id: str + doc_id: str + score: float + +class Metric(NamedTuple): + query_id: str + measure: 'Measure' + value: Union[float, int] + + +@deprecation.deprecated(deprecated_in="0.2.0", + details="Please use ir_measures.Qrel() instead") +class GenericQrel(Qrel): + pass +GenericQrel._fields = Qrel._fields + +@deprecation.deprecated(deprecated_in="0.2.0", + details="Please use ir_measures.ScoredDoc() instead") +class GenericScoredDoc(ScoredDoc): + pass +GenericScoredDoc._fields = ScoredDoc._fields class QrelsConverter: @@ -34,7 +59,7 @@ def predict_type(self): result = 'dict_of_dict' elif hasattr(self.qrels, 'itertuples'): cols = self.qrels.columns - if all(i in cols for i in GenericQrel._fields): + if all(i in cols for i in Qrel._fields): result = 'pd_dataframe' elif hasattr(self.qrels, '__iter__'): # peek @@ -43,7 +68,7 @@ def predict_type(self): sentinal = object() item = next(peek_qrels, sentinal) if isinstance(item, tuple) and hasattr(item, '_fields'): - if all(i in item._fields for i in GenericQrel._fields): + if all(i in item._fields for i in Qrel._fields): result = 'namedtuple_iter' elif item is sentinal: result = 'namedtuple_iter' @@ -69,7 +94,7 @@ def as_namedtuple_iter(self): if t == 'dict_of_dict': for query_id, docs in self.qrels.items(): for doc_id, relevance in docs.items(): - yield GenericQrel(query_id=query_id, doc_id=doc_id, relevance=relevance) + yield Qrel(query_id=query_id, doc_id=doc_id, relevance=relevance) if t == 'pd_dataframe': yield from self.qrels.itertuples() if t == 'UNKNOWN': @@ -111,7 +136,7 @@ def predict_type(self): return 'dict_of_dict' if hasattr(self.run, 'itertuples'): cols = self.run.columns - if all(i in cols for i in GenericScoredDoc._fields): + if all(i in cols for i in ScoredDoc._fields): return 'pd_dataframe' if hasattr(self.run, '__iter__'): # peek @@ -120,7 +145,7 @@ def predict_type(self): sentinal = object() item = next(peek_run, sentinal) if isinstance(item, tuple) and hasattr(item, '_fields'): - if all(i in item._fields for i in GenericScoredDoc._fields): + if all(i in item._fields for i in ScoredDoc._fields): return 'namedtuple_iter' if item is sentinal: return 'namedtuple_iter' @@ -145,7 +170,7 @@ def as_namedtuple_iter(self): if t == 'dict_of_dict': for query_id, docs in self.run.items(): for doc_id, score in docs.items(): - yield GenericScoredDoc(query_id=query_id, doc_id=doc_id, score=score) + yield ScoredDoc(query_id=query_id, doc_id=doc_id, score=score) if t == 'pd_dataframe': yield from self.run.itertuples() if t == 'UNKNOWN': @@ -184,7 +209,7 @@ def read_trec_qrels(file): for line in file: if line.strip(): query_id, iteration, doc_id, relevance = line.split() - yield GenericQrel(query_id=query_id, 
doc_id=doc_id, relevance=int(relevance)) + yield Qrel(query_id=query_id, doc_id=doc_id, relevance=int(relevance)) elif isinstance(file, str): if '\n' in file: yield from read_trec_qrels(io.StringIO(file)) @@ -202,7 +227,7 @@ def read_trec_run(file): for line in file: if line.strip(): query_id, iteration, doc_id, rank, score, tag = line.split() - yield GenericScoredDoc(query_id=query_id, doc_id=doc_id, score=float(score)) + yield ScoredDoc(query_id=query_id, doc_id=doc_id, score=float(score)) elif isinstance(file, str): if '\n' in file: yield from read_trec_run(io.StringIO(file)) @@ -233,7 +258,7 @@ def _ast_to_value(node): raise ValueError(_AST_PARSE_ERROR.format('values must be str, float, int, bool, etc.')) -def parse_measure(measure: str) -> 'BaseMeasure': +def parse_measure(measure: str) -> 'Measure': try: node = ast.parse(measure).body except SyntaxError as e: @@ -266,9 +291,7 @@ def parse_measure(measure: str) -> 'BaseMeasure': return measure(**args) - - -def convert_trec_name(measure: str) -> List['BaseMeasure']: +def parse_trec_measure(measure: str) -> List['Measure']: TREC_NAME_MAP = { 'ndcg': (ir_measures.nDCG, None, None), 'P': (ir_measures.P, 'cutoff', [5, 10, 15, 20, 30, 100, 200, 500, 1000]), @@ -311,7 +334,7 @@ def convert_trec_name(measure: str) -> List['BaseMeasure']: skipped = [] for sub_name in sorted(pytrec_eval.supported_nicknames[measure]): try: - result += convert_trec_name(sub_name) + result += parse_trec_measure(sub_name) except ValueError: if sub_name != 'runid': skipped.append(sub_name) @@ -353,3 +376,9 @@ def convert_trec_name(measure: str) -> List['BaseMeasure']: for arg in meas_args.split(','): result.append(meas(**{arg_name: dtype(arg)})) return result + + +@deprecation.deprecated(deprecated_in="0.2.0", + details="Please use ir_measures.parse_trec_measure() instead") +def convert_trec_name(measure: str) -> List['Measure']: + return parse_trec_measure(measure) diff --git a/test/test_util.py b/test/test_util.py index 761fb9e..c27816e 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -5,7 +5,7 @@ class TestUtil(unittest.TestCase): - def test_convert_trec_name(self): + def test_parse_trec_measure(self): cases = { 'map': [AP], 'P_5': [P@5], @@ -23,7 +23,7 @@ def test_convert_trec_name(self): } for case in cases: with self.subTest(case): - self.assertEqual(ir_measures.convert_trec_name(case), cases[case]) + self.assertEqual(ir_measures.parse_trec_measure(case), cases[case]) def test_parse_measure(self): tests = {
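For orientation, the snippet below sketches how the renamed pieces in this diff fit together: ``Qrel``/``ScoredDoc`` inputs, measure objects and the string parsers, and the ``calc_aggregate``/``iter_calc``/``evaluator`` entry points. It is a minimal illustration only; the qrels, run, and measure choices are made up, and the printed values depend on your data.

```python
# Minimal sketch of the renamed ir_measures API (illustrative data only).
import ir_measures
from ir_measures import AP, P, nDCG, Qrel, ScoredDoc

# Qrels and runs can be any iterable of the new Qrel/ScoredDoc namedtuples
# (GenericQrel/GenericScoredDoc still work but are deprecated).
qrels = [
    Qrel("Q0", "D0", 0),
    Qrel("Q0", "D1", 1),
    Qrel("Q1", "D0", 0),
    Qrel("Q1", "D3", 2),
]
run = [
    ScoredDoc("Q0", "D0", 1.2),
    ScoredDoc("Q0", "D1", 1.0),
    ScoredDoc("Q1", "D0", 2.4),
    ScoredDoc("Q1", "D3", 3.6),
]

# Measures can be written directly, parsed from strings, or mapped from
# trec_eval names (parse_trec_measure returns a list).
measures = [AP, P(rel=2)@10]
measures.append(ir_measures.parse_measure('nDCG@20'))
measures += ir_measures.parse_trec_measure('ndcg_cut_10')  # -> [nDCG@10]

# Aggregate scores over all queries: a dict of measure -> value.
print(ir_measures.calc_aggregate(measures, qrels, run))

# Per-query scores are yielded as Metric namedtuples.
for metric in ir_measures.iter_calc([AP], qrels, run):
    print(metric.query_id, metric.measure, metric.value)

# When scoring several runs against the same qrels, build an evaluator once
# and reuse it (this replaces the deprecated calc_ctxt).
evaluator = ir_measures.evaluator(measures, qrels)
print(evaluator.calc_aggregate(run))
```

The deprecated names (``GenericQrel``, ``GenericScoredDoc``, ``convert_trec_name``, ``calc_ctxt``) are kept for backwards compatibility but are marked deprecated, so new code should prefer the names shown above.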