diff --git a/cardinal/metrics.py b/cardinal/metrics.py
new file mode 100644
index 0000000..542e1c7
--- /dev/null
+++ b/cardinal/metrics.py
@@ -0,0 +1,86 @@
+import logging
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+
+class BaseMonitor(ABC):
+    """A monitor is a metric and a set of utils to record it and monitor it.
+
+    Args:
+        batch_size: If specified, a warning is issued whenever the number of
+            samples added in an iteration differs from this batch size.
+        tolerance: Absolute difference between successive values below which
+            the metric is considered stalled.
+    """
+
+    def __init__(self, batch_size=None, tolerance=None):
+        self.batch_size = batch_size
+        self.tolerance = tolerance
+        self.reset()
+
+    def reset(self):
+        self.n_samples = []
+        self.values = []
+
+    def _append_n_samples(self, n_samples):
+        self.n_samples.append(n_samples)
+        if not self.batch_size or len(self.n_samples) <= 1:
+            return
+        this_batch_size = self.n_samples[-1] - self.n_samples[-2]
+        if this_batch_size != self.batch_size:
+            logging.warning(
+                'Batch size of iteration {} is {} which is different '
+                'from the reference batch size {}'.format(
+                    len(self.n_samples), this_batch_size, self.batch_size
+                )
+            )
+
+    @abstractmethod
+    def accumulate(self, n_samples, value):
+        pass
+
+    @abstractmethod
+    def get(self):
+        pass
+
+    def is_stalled(self, n_iter=1):
+        """Returns True if the last n_iter values vary by less than tolerance."""
+        if len(self.values) < n_iter + 1:
+            return False
+        for prev_v, curr_v in zip(self.values[-n_iter - 1:-1], self.values[-n_iter:]):
+            if np.abs(curr_v - prev_v) > self.tolerance:
+                return False
+        return True
+
+
+class ContradictionMonitor(BaseMonitor):
+    """Stores the amount of contradictions along an experiment.
+
+    We call contradiction the difference between the predictions of two
+    successive models on an isolated test set.
+    """
+
+    def accumulate(self, n_samples: int, probas_test: np.ndarray):
+        """Stores the contradiction for a new iteration.
+
+        Args:
+            n_samples: Number of training samples.
+            probas_test: Predictions of shape (n_samples, n_classes).
+        """
+        if self.last_probas_test is not None:
+            self.values.append(
+                np.abs(probas_test - self.last_probas_test).sum())
+        self._append_n_samples(n_samples)
+        self.last_probas_test = probas_test
+
+    def get(self):
+        """Returns the recorded metrics."""
+        return {
+            "n_samples": self.n_samples,
+            "contradictions": self.values
+        }
+
+    def reset(self):
+        """Resets the metrics for a new experiment."""
+        super().reset()
+        self.last_probas_test = None
diff --git a/examples/plot_digits_metrics.py b/examples/plot_digits_metrics.py
index 4d54a0e..01780b7 100644
--- a/examples/plot_digits_metrics.py
+++ b/examples/plot_digits_metrics.py
@@ -26,6 +26,7 @@
 from cardinal.random import RandomSampler
 from cardinal.plotting import plot_confidence_interval
 from cardinal.base import BaseQuerySampler
+from cardinal.metrics import ContradictionMonitor
 
 np.random.seed(7)
 
@@ -147,7 +148,7 @@ def select_samples(self, X):
         train_test_split(X, y, test_size=500, random_state=k)
 
     accuracies = []
-    contradictions = []
+    contradictions = ContradictionMonitor()
    explorations = []
     previous_proba = None
@@ -167,13 +168,8 @@ def select_samples(self, X):
         # Record metrics
         accuracies.append(model.score(X_test, y_test))
         explorations.append(compute_exploration(X_train[mask], X_test))
-
-        # Contradictions depend on the previous iteration
-        current_proba = model.predict_proba(X_test)
-        if previous_proba is not None:
-            contradictions.append(compute_contradiction(
-                previous_proba, current_proba))
-        previous_proba = current_proba
+        contradictions.accumulate(len(selected),
+                                  model.predict_proba(X_test))
 
         sampler.fit(X_train[mask], y_train[mask])
         selected = sampler.select_samples(X_train[~mask])
@@ -181,7 +177,7 @@ def select_samples(self, X):
 
     all_accuracies.append(accuracies)
     all_explorations.append(explorations)
-    all_contradictions.append(contradictions)
+    all_contradictions.append(contradictions.get()['contradictions'])
 
 
 x_data = np.arange(10, batch_size * (n_iter - 1) + 11, batch_size)
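For reviewers, here is a minimal sketch of how the new `ContradictionMonitor` is meant to be driven on its own. The synthetic dataset, the `RandomForestClassifier`, and the fixed labelling schedule below are illustrative placeholders, not part of this patch or of the example script.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from cardinal.metrics import ContradictionMonitor

# Placeholder data and model: any classifier exposing predict_proba works.
X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=200,
                                                    random_state=0)

monitor = ContradictionMonitor(batch_size=10, tolerance=1e-2)

n_labeled = 100
for _ in range(5):
    model = RandomForestClassifier(random_state=0).fit(
        X_train[:n_labeled], y_train[:n_labeled])
    # Contradictions are computed against the probabilities stored at the
    # previous call, so the first call only records the baseline predictions.
    monitor.accumulate(n_labeled, model.predict_proba(X_test))
    n_labeled += 10

print(monitor.get())          # {'n_samples': [...], 'contradictions': [...]}
print(monitor.is_stalled(2))  # True if the last 2 values differ by < tolerance
```

Passing `n_samples` explicitly lets the monitor warn about an unexpected batch size and makes `get()` return x-axis values ready for plotting alongside the contradiction values.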