Merge pull request #12 from dataiku/metrics
Add a metrics module to easily log them
dsleo authored Jun 2, 2020
2 parents 1dc1c19 + d095485 commit ee29df4
Showing 2 changed files with 92 additions and 9 deletions.
87 changes: 87 additions & 0 deletions cardinal/metrics.py
@@ -0,0 +1,87 @@
import logging
from abc import ABC, abstractmethod

import numpy as np


class BaseMonitor(ABC):
    """A monitor records a metric across iterations and provides utilities
    to inspect its evolution.

    Args:
        batch_size: If specified, a warning is logged whenever the number of
            samples added between two iterations differs from this value.
        tolerance: Maximum absolute change between two successive values of
            the metric below which it is considered stalled.
    """

    def __init__(self, batch_size=None, tolerance=None):
        self.batch_size = batch_size
        self.tolerance = tolerance
        self.reset()

    def reset(self):
        self.n_samples = []
        self.values = []

    def _append_n_samples(self, n_samples):
        self.n_samples.append(n_samples)
        if not self.batch_size or len(self.n_samples) <= 1:
            return
        this_batch_size = self.n_samples[-1] - self.n_samples[-2]
        if this_batch_size != self.batch_size:
            logging.warning(
                'Batch size of iteration {} is {} which is different '
                'from the reference batch size {}'.format(
                    len(self.n_samples), this_batch_size, self.batch_size
                )
            )

    @abstractmethod
    def accumulate(self, n_samples, value):
        pass

    @abstractmethod
    def get(self):
        pass

    def is_stalled(self, n_iter=1):
        if len(self.values) < n_iter + 1:
            return False
        for prev_v, curr_v in zip(self.values[-n_iter - 1:-1],
                                  self.values[-n_iter:]):
            if np.abs(curr_v - prev_v) > self.tolerance:
                return False
        return True


class ContradictionMonitor(BaseMonitor):
    """Stores the amount of contradiction along an experiment.

    We call contradiction the difference between the predictions of two
    successive models on an isolated test set.
    """

"""Stores contradiction for a new iteration.
Args:
n_samples : Number of training samples
probas_test : Predictions of shape (n_samples, n_classes)
"""
def accumulate(self, n_samples: int, probas_test: np.array):
if self.last_probas_test is not None:
self.values.append(
np.abs(probas_test - self.last_probas_test).sum())
self._append_n_samples(n_samples)
self.last_probas_test = probas_test

"""Returns the recorded metrics
"""
def get(self):
return {
"n_samples": self.n_samples,
"contradictions": self.values
}

"""Reset the metrics for a new experiment
"""
def reset(self):
super().reset()
self.last_probas_test = None
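
For context, here is a minimal usage sketch of the new monitor, following the API shown above; the toy probability arrays, batch size, and tolerance are made up for illustration:

import numpy as np

from cardinal.metrics import ContradictionMonitor

# Expect batches of 10 samples; consider the metric stalled once two
# successive contradiction values differ by less than 0.05.
monitor = ContradictionMonitor(batch_size=10, tolerance=0.05)

rng = np.random.RandomState(0)
for iteration in range(3):
    n_labeled = 10 * (iteration + 1)
    # Stand-in for model.predict_proba(X_test) on a fixed test set.
    probas_test = rng.dirichlet(np.ones(3), size=50)
    monitor.accumulate(n_labeled, probas_test)

print(monitor.get())         # {'n_samples': [10, 20, 30], 'contradictions': [...]}
print(monitor.is_stalled())  # True only if the last change is within tolerance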
14 changes: 5 additions & 9 deletions examples/plot_digits_metrics.py
@@ -26,6 +26,7 @@
 from cardinal.random import RandomSampler
 from cardinal.plotting import plot_confidence_interval
 from cardinal.base import BaseQuerySampler
+from cardinal.metrics import ContradictionMonitor
 
 np.random.seed(7)
 
@@ -147,7 +148,7 @@ def select_samples(self, X):
         train_test_split(X, y, test_size=500, random_state=k)
 
     accuracies = []
-    contradictions = []
+    contradictions = ContradictionMonitor()
     explorations = []
 
-    previous_proba = None
@@ -167,21 +168,16 @@ def select_samples(self, X):
         # Record metrics
         accuracies.append(model.score(X_test, y_test))
         explorations.append(compute_exploration(X_train[mask], X_test))
 
-        # Contradictions depend on the previous iteration
-        current_proba = model.predict_proba(X_test)
-        if previous_proba is not None:
-            contradictions.append(compute_contradiction(
-                previous_proba, current_proba))
-        previous_proba = current_proba
+        contradictions.accumulate(len(selected),
+                                  model.predict_proba(X_test))
 
         sampler.fit(X_train[mask], y_train[mask])
         selected = sampler.select_samples(X_train[~mask])
         mask[indices[~mask][selected]] = True
 
     all_accuracies.append(accuracies)
     all_explorations.append(explorations)
-    all_contradictions.append(contradictions)
+    all_contradictions.append(contradictions.get()['contradictions'])
 
 x_data = np.arange(10, batch_size * (n_iter - 1) + 11, batch_size)

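As a closing note, other metrics can be tracked through the same interface by subclassing BaseMonitor and implementing its two abstract methods. The AccuracyMonitor below is a hypothetical example, not part of this commit:

from cardinal.metrics import BaseMonitor

class AccuracyMonitor(BaseMonitor):
    """Hypothetical monitor recording test accuracy over iterations."""

    def accumulate(self, n_samples, accuracy):
        # Record the sample count (reusing the batch-size check from the
        # base class) and store the measured accuracy as-is.
        self._append_n_samples(n_samples)
        self.values.append(accuracy)

    def get(self):
        return {"n_samples": self.n_samples, "accuracies": self.values}

Because values are stored in BaseMonitor.values, is_stalled works out of the box, e.g. to stop labeling once accuracy plateaus.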