From 28959485f7c41b078d8ca7443b9bf377f164282c Mon Sep 17 00:00:00 2001
From: Oege Dijk
Date: Sun, 31 Jan 2021 14:35:33 +0100
Subject: [PATCH] update shape precision component

---
 RELEASE_NOTES.md                              |  7 +-
 .../classifier_components.py                  | 15 +++-
 explainerdashboard/explainers.py              | 84 +++++++++++++------
 3 files changed, 77 insertions(+), 29 deletions(-)

diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
index 41331b0..e39b759 100644
--- a/RELEASE_NOTES.md
+++ b/RELEASE_NOTES.md
@@ -1,7 +1,8 @@
 # Release Notes
 
 ## Version 0.3.1:
-
+This version is mostly about pre-calculating and optimizing the classifier statistics
+components. Those components should now be much more responsive with large datasets.
 
 ### New Features
 - new methods `roc_auc_curve(pos_label)` and `pr_auc_curve(pos_label)`
@@ -29,7 +30,11 @@
   - dashboard should be more responsive for large datasets
 - pre-calculating confusion matrices
   - dashboard should be more responsive for large datasets
+- pre-calculating classification_dfs
+  - dashboard should be more responsive for large datasets
 - confusion matrix: added axis title, moved predicted labels to bottom of graph
+- precision plot: when only the cutoff is adjusted, only the cutoff line is
+  updated, without recalculating the whole plot.
 ### Other Changes
 -
 -
diff --git a/explainerdashboard/dashboard_components/classifier_components.py b/explainerdashboard/dashboard_components/classifier_components.py
index 20114d6..641bd33 100644
--- a/explainerdashboard/dashboard_components/classifier_components.py
+++ b/explainerdashboard/dashboard_components/classifier_components.py
@@ -297,6 +297,8 @@ def __init__(self, explainer, title="Prediction", name=None,
         Shows the predicted probability for each {self.explainer.target} label.
         """
 
+        self.register_dependencies("metrics")
+
     def layout(self):
         return dbc.Card([
             make_hideable(
@@ -590,9 +592,18 @@ def update_div_visibility(bins_or_quantiles):
             Input('precision-cutoff-'+self.name, 'value'),
             Input('precision-multiclass-'+self.name, 'value'),
             Input('pos-label-'+self.name, 'value')],
-            #[State('tabs', 'value')],
+            [State('precision-graph-'+self.name, 'figure')],
         )
-        def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_label):
+        def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_label, fig):
+            ctx = dash.callback_context
+            trigger = ctx.triggered[0]['prop_id'].split('.')[0]
+            if trigger == 'precision-cutoff-'+self.name and fig is not None:
+                return go.Figure(fig).update_shapes(dict(
+                            type='line',
+                            xref='x', yref='y2',
+                            x0=cutoff, x1=cutoff,
+                            y0=0, y1=1.0,
+                        ))
             if bins == 'bin_size':
                 return self.explainer.plot_precision(
                     bin_size=bin_size, cutoff=cutoff,
diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py
index fcc0de1..b4182ef 100644
--- a/explainerdashboard/explainers.py
+++ b/explainerdashboard/explainers.py
@@ -2068,19 +2068,35 @@ def metrics(self, cutoff=0.5, pos_label=None):
         """
         if self.y_missing:
             raise ValueError("No y was passed to explainer, so cannot calculate metrics!")
-        y_true = self.y_binary(pos_label)
-        y_pred = np.where(self.pred_probas(pos_label) > cutoff, 1, 0)
-        metrics_dict = {
-            'accuracy' : accuracy_score(y_true, y_pred),
-            'precision' : precision_score(y_true, y_pred),
-            'recall' : recall_score(y_true, y_pred),
-            'f1' : f1_score(y_true, y_pred),
-            'roc_auc_score' : roc_auc_score(y_true, self.pred_probas(pos_label)),
-            'pr_auc_score' : average_precision_score(y_true, self.pred_probas(pos_label)),
-            'log_loss' : log_loss(y_true, self.pred_probas(pos_label))
-        }
-
-        return metrics_dict
+        def get_metrics(cutoff, pos_label):
+            y_true = self.y_binary(pos_label)
+            y_pred = np.where(self.pred_probas(pos_label) > cutoff, 1, 0)
+
+            metrics_dict = {
+                'accuracy' : accuracy_score(y_true, y_pred),
+                'precision' : precision_score(y_true, y_pred, zero_division=0),
+                'recall' : recall_score(y_true, y_pred),
+                'f1' : f1_score(y_true, y_pred),
+                'roc_auc_score' : roc_auc_score(y_true, self.pred_probas(pos_label)),
+                'pr_auc_score' : average_precision_score(y_true, self.pred_probas(pos_label)),
+                'log_loss' : log_loss(y_true, self.pred_probas(pos_label))
+            }
+            return metrics_dict
+
+        if not hasattr(self, "_metrics"):
+            _ = self.pred_probas()
+            print("Calculating metrics...", flush=True)
+            self._metrics = dict()
+            for label in range(len(self.labels)):
+                self._metrics[label] = dict()
+                for cut in np.linspace(0.01, 0.99, 99):
+                    self._metrics[label][np.round(cut, 2)] = \
+                        get_metrics(cut, label)
+
+        if cutoff in self._metrics[pos_label]:
+            return self._metrics[pos_label][cutoff]
+        else:
+            return get_metrics(cutoff, pos_label)
 
     @insert_pos_label
     def metrics_descriptions(self, cutoff=0.5, round=3, pos_label=None):
@@ -2253,27 +2269,41 @@ def get_liftcurve_df(self, pos_label=None):
         return self._liftcurve_dfs[pos_label]
 
     @insert_pos_label
-    def get_classification_df(self, cutoff=0.5, percentage=False, pos_label=None):
+    def get_classification_df(self, cutoff=0.5, pos_label=None):
         """Returns a dataframe with number of observations in each class above
         and below the cutoff.
 
         Args:
             cutoff (float, optional): Cutoff to split on. Defaults to 0.5.
-            percentage (bool, optional): Normalize results. Defaults to False.
             pos_label (int, optional): Pos label to generate dataframe for.
                 Defaults to self.pos_label.
 
         Returns:
             pd.DataFrame
         """
-        clas_df = pd.DataFrame(index=pd.RangeIndex(0, len(self.labels)))
-        clas_df['below'] = self.y[self.pred_probas(pos_label) < cutoff].value_counts(normalize=percentage)
-        clas_df['above'] = self.y[self.pred_probas(pos_label) >= cutoff].value_counts(normalize=percentage)
-        clas_df = clas_df.fillna(0)
-        clas_df['total'] = clas_df.sum(axis=1)
-        clas_df.index = self.labels
-        return clas_df
-
+        def get_clas_df(cutoff, pos_label):
+            clas_df = pd.DataFrame(index=pd.RangeIndex(0, len(self.labels)))
+            clas_df['below'] = self.y[self.pred_probas(pos_label) < cutoff].value_counts()
+            clas_df['above'] = self.y[self.pred_probas(pos_label) >= cutoff].value_counts()
+            clas_df = clas_df.fillna(0)
+            clas_df['total'] = clas_df.sum(axis=1)
+            clas_df.index = self.labels
+            return clas_df
+
+        if not hasattr(self, "_classification_dfs"):
+            _ = self.pred_probas()
+            print("Calculating classification_dfs...", flush=True)
+            self._classification_dfs = dict()
+            for label in range(len(self.labels)):
+                self._classification_dfs[label] = dict()
+                for cut in np.linspace(0.01, 0.99, 99):
+                    self._classification_dfs[label][np.round(cut, 2)] = \
+                        get_clas_df(cut, label)
+        if cutoff in self._classification_dfs[pos_label]:
+            return self._classification_dfs[pos_label][cutoff]
+        else:
+            return get_clas_df(cutoff, pos_label)
+
     @insert_pos_label
     def roc_auc_curve(self, pos_label=None):
         """Returns a dict with output from sklearn.metrics.roc_curve() for pos_label:
@@ -2314,9 +2344,9 @@ def get_binary_cm(y, pred_probas, cutoff, pos_label):
             self._confusion_matrices['binary'] = dict()
             for label in range(len(self.labels)):
                 self._confusion_matrices['binary'][label] = dict()
-                for cutoff in np.linspace(0.01, 0.99, 99):
-                    self._confusion_matrices['binary'][label][np.round(cutoff, 2)] = \
-                        get_binary_cm(self.y, self.pred_probas_raw, cutoff, label)
+                for cut in np.linspace(0.01, 0.99, 99):
+                    self._confusion_matrices['binary'][label][np.round(cut, 2)] = \
+                        get_binary_cm(self.y, self.pred_probas_raw, cut, label)
             self._confusion_matrices['multi'] = confusion_matrix(self.y, self.pred_probas_raw.argmax(axis=1))
         if binary:
             if cutoff in self._confusion_matrices['binary'][pos_label]:
@@ -2447,7 +2477,9 @@ def plot_classification(self, cutoff=0.5, percentage=True, pos_label=None):
             plotly fig
 
         """
-        return plotly_classification_plot(self.get_classification_df(cutoff=cutoff, pos_label=pos_label), percentage=percentage)
+        return plotly_classification_plot(
+            self.get_classification_df(cutoff=cutoff, pos_label=pos_label),
+            percentage=percentage)
 
     @insert_pos_label
     def plot_roc_auc(self, cutoff=0.5, pos_label=None):
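
The caching pattern introduced above for metrics(), get_classification_df() and the confusion
matrices is the same in each case: pre-calculate the statistic for every label and for 99 cutoffs
rounded to two decimals, then serve the cached value and only fall back to an on-the-fly
calculation for cutoffs off that grid. A minimal standalone sketch of that pattern follows; the
names expensive_stat and PrecalculatedCache are illustrative placeholders, not part of the
explainerdashboard API:

    import numpy as np

    def expensive_stat(cutoff, label):
        # stand-in for a costly per-cutoff, per-label calculation,
        # such as get_metrics() or get_clas_df() in the patch above
        return {"cutoff": cutoff, "label": label}

    class PrecalculatedCache:
        def __init__(self, n_labels):
            self._store = {}
            for label in range(n_labels):
                self._store[label] = {}
                for cut in np.linspace(0.01, 0.99, 99):
                    # keys rounded to two decimals, matching the cutoff slider steps
                    self._store[label][np.round(cut, 2)] = expensive_stat(cut, label)

        def get(self, cutoff, label):
            # serve the pre-calculated value when the cutoff lies on the grid,
            # otherwise compute it on demand
            if cutoff in self._store[label]:
                return self._store[label][cutoff]
            return expensive_stat(cutoff, label)

    cache = PrecalculatedCache(n_labels=2)
    print(cache.get(0.57, 1))    # should hit the pre-calculated grid
    print(cache.get(0.505, 1))   # off the grid, computed on the fly

The precision-graph callback avoids a full replot in the same spirit: if dash.callback_context
shows that only the cutoff slider fired, it moves the existing cutoff line with
Figure.update_shapes() instead of calling plot_precision() again. A Plotly-only sketch of that
shape update, using a dummy figure rather than the real precision plot:

    import plotly.graph_objects as go

    fig = go.Figure(go.Scatter(x=[0, 1], y=[0, 1]))
    # initial cutoff line at 0.5 (yref='paper' here because this dummy figure
    # has no secondary y-axis, unlike the actual precision plot)
    fig.add_shape(type='line', xref='x', yref='paper', x0=0.5, x1=0.5, y0=0, y1=1.0)

    new_cutoff = 0.75
    # update_shapes() patches the existing shape(s) in place, which is much cheaper
    # than recomputing the precision dataframe and rebuilding the whole figure
    fig = go.Figure(fig).update_shapes(dict(x0=new_cutoff, x1=new_cutoff))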