Add summation #67
base: develop
Conversation
@@ -247,6 +247,21 @@ def __iter__(self):
        for uri, component in self.results_:
            yield uri, component

    def __add__(self, other):
        cls = self.__class__
        result = cls()
Can we find a way to make sure `result` is initialized with the same options (e.g. `collar` and `skip_overlap` for `DiarizationErrorRate` instances) as `self` and `other`? This probably means adding some kind of `sklearn`-like mechanism to clone metrics.
Good point! I cribbed this PR from some monkey patching I did a couple of years ago for an internal SAD scoring tool, and in that context default parameter values were being used, so the issue didn't come up. After looking at how `sklearn` handles this, maybe we could add a similar method to ensure the resulting instance is initialized with the same arguments as the first summand. If so, I should also document that `sum([m1, m2, ...])` assumes all metrics were initialized identically, which seems a reasonable assumption.
I've made an initial attempt at a `clone` function, which also required me to implement a `get_params` method for metrics. NOTE that `get_params` assumes that all subclasses of `BaseMetric` include `**kwargs` in their signature and pass these keyword arguments to the constructor of the superclass (or, if there are multiple superclasses, to one of them). Should this assumption be violated, weirdness could ensue.
An alternate approach would be the one used within `sklearn`, which bans use of `*args` and `**kwargs` within constructors and forces each metric to be explicit about its parameters. This would require touching more lines of the codebase but, beyond being a bit of a chore, shouldn't be difficult to implement.
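To make the discussion concrete, here is a minimal sketch of the `get_params`/`clone` mechanism under the `**kwargs` assumption described above. The class bodies are illustrative assumptions, not the actual pyannote.metrics implementation:

```python
class BaseMetric:
    """Illustrative stand-in for pyannote's BaseMetric (not the real code)."""

    def __init__(self, **kwargs):
        # Assumed convention: subclasses forward their parameters here via
        # **kwargs, so the base class can record them for later cloning.
        self._params = dict(kwargs)

    def get_params(self):
        # Return the constructor arguments this instance was built with.
        return dict(self._params)

    def clone(self):
        # Build a fresh, "unfit" instance with identical parameters.
        return self.__class__(**self.get_params())


class DiarizationErrorRate(BaseMetric):
    """Toy subclass following the assumed **kwargs-forwarding convention."""

    def __init__(self, collar=0.0, skip_overlap=False, **kwargs):
        super().__init__(collar=collar, skip_overlap=skip_overlap, **kwargs)
        self.collar = collar
        self.skip_overlap = skip_overlap


metric = DiarizationErrorRate(collar=0.25)
copy = metric.clone()  # distinct instance, same parameters
```

If a subclass failed to forward a parameter through `**kwargs`, `clone` would silently fall back to that parameter's default, which is the "weirdness" mentioned above.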
        return result

    def __radd__(self, other):
        if other == 0:
Is this how the built-in `sum` function initializes its own accumulator? Would be nice to add a quick comment mentioning this...
The built-in `sum` actually has a `start` parameter that controls the initial value of the summation. As you might gather, the default value of `start` is 0, so I just hard-coded that value as an additive identity for metrics. It would probably be good to add a one- or two-line comment to this effect to save someone from having to read up on `__radd__` and `sum`.
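For reference, a minimal self-contained sketch (not the PR's exact code) of why `__radd__` must special-case 0: `sum()` starts its accumulator at `start=0` by default, so the very first addition it performs is `0 + first_item`, which dispatches to the item's `__radd__`:

```python
class Accumulator:
    """Toy stand-in for a metric that supports sum()."""

    def __init__(self, total=0.0):
        self.total = total

    def __add__(self, other):
        # Merge two accumulators.
        return Accumulator(self.total + other.total)

    def __radd__(self, other):
        # sum([...]) evaluates 0 + first_item, so treat 0 as the
        # additive identity and return a copy of self unchanged.
        if other == 0:
            return Accumulator(self.total)
        return self.__add__(other)


total = sum([Accumulator(1.0), Accumulator(2.0), Accumulator(3.0)])
```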
Adding back a "parallel processing" entry in the documentation (mentioning your `joblib` example, for instance) would be nice to have as well!
@@ -36,6 +37,51 @@
from pyannote.metrics.types import Details, MetricComponents


# TODO:
# - fit/unfit = proper terminology?
def clone(metric):
Can we make it a method of `BaseMetric`?
No objection. Preference for method or staticmethod?
It needs access to `self`, doesn't it? So regular method. Unless there is something I don't understand about the distinction between method and staticmethod...
        cls = self.__class__
        result = cls()
Once `clone` becomes a method of `BaseMetric`, we would do:

-        cls = self.__class__
-        result = cls()
+        result = self.clone()
def test_clone():
    # Tests that clone creates deep copy of "unfit" metric.
    metric = M1(a=10)
    metric_new = clone(metric)
Once `clone` is a method of `BaseMetric`:

-    metric_new = clone(metric)
+    metric_new = metric.clone()
Co-authored-by: Hervé BREDIN <[email protected]>
Overview

Adds the following methods to `BaseMetric` to support summation:
- `__add__`
- `__radd__`
Motivation

The motivation is two-fold:

1. Suppose we want to compute metrics for a large volume of data -- sufficiently large that we would like to parallelize the collection of sufficient statistics. Computing the sufficient statistics using `multiprocessing` or `joblib` is straightforward. However, the sufficient statistics for DER computation are then spread across the file-level metrics; combining these into a single instance reduces to a `sum` over those instances.
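The original `joblib` example is not preserved here, but the workflow can be sketched as follows. `ToyErrorRate` is a hypothetical stand-in for `DiarizationErrorRate`, and the plain list comprehension marks where a parallel map would go, so the sketch stays dependency-free:

```python
class ToyErrorRate:
    """Hypothetical stand-in for DiarizationErrorRate: holds per-file
    sufficient statistics and supports summation."""

    def __init__(self, errors=0.0, total=0.0):
        self.errors = errors
        self.total = total

    def __add__(self, other):
        # Merge sufficient statistics from two instances.
        return ToyErrorRate(self.errors + other.errors,
                            self.total + other.total)

    def __radd__(self, other):
        # Support sum(), whose accumulator starts at 0.
        return self if other == 0 else self.__add__(other)

    def rate(self):
        return self.errors / self.total


def score_file(pair):
    # In reality this would score one (reference, hypothesis) pair.
    errors, total = pair
    return ToyErrorRate(errors, total)


# In practice: Parallel(n_jobs=-1)(delayed(score_file)(f) for f in files)
per_file = [score_file(p) for p in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]]

# Combining the file-level metrics into a single instance:
overall = sum(per_file)
```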
2. Suppose we want to compute metrics not just overall and at the file level, but by various logical subdivisions; e.g., DIHARD III domains. This is now trivial using `pandas` dataframes. E.g., suppose `data` contains the following columns:
   - `file_id` -- file id
   - `domain` -- domain the file is from
   - `der` -- instance of `DiarizationErrorRate`
   Grouping by `domain` and summing the `der` column then yields domain-level metrics, from which suitable reports may be generated.
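A sketch of that aggregation, assuming the `pandas` version would be roughly `data.groupby("domain")["der"].sum()`. To keep the sketch dependency-free, the same grouping is done with a plain dict, and `ToyDER` is a hypothetical stand-in for `DiarizationErrorRate`:

```python
from collections import defaultdict


class ToyDER:
    """Hypothetical stand-in for DiarizationErrorRate."""

    def __init__(self, errors=0.0, total=0.0):
        self.errors = errors
        self.total = total

    def __add__(self, other):
        # Merge sufficient statistics.
        return ToyDER(self.errors + other.errors, self.total + other.total)

    def __radd__(self, other):
        # Support sum(), whose accumulator starts at 0.
        return self if other == 0 else self.__add__(other)


# rows: (file_id, domain, per-file metric instance)
rows = [
    ("f1", "broadcast", ToyDER(1.0, 10.0)),
    ("f2", "broadcast", ToyDER(1.0, 10.0)),
    ("f3", "meeting",   ToyDER(4.0, 10.0)),
]

# Equivalent of data.groupby("domain")["der"].sum() without pandas:
by_domain = defaultdict(list)
for _, domain, der in rows:
    by_domain[domain].append(der)

domain_metrics = {domain: sum(ders) for domain, ders in by_domain.items()}
```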