From 01bce0400ec6278a7d8418d415478113bbd24a48 Mon Sep 17 00:00:00 2001 From: Oege Dijk Date: Mon, 11 Mar 2024 22:28:49 +0100 Subject: [PATCH] fix shap_interaction_values for shap 0.45 --- explainerdashboard/explainers.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index 54315a1..33a1e23 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -3068,9 +3068,21 @@ def shap_interaction_values(self, pos_label=None): self._shap_interaction_values = self.shap_explainer.shap_interaction_values( self.X ) - if len(self.labels) == 2: - if not isinstance(self._shap_interaction_values, list): + if ( + isinstance(self._shap_interaction_values, np.ndarray) + and len(self._shap_interaction_values.shape) == 4 + and self._shap_interaction_values.shape[3] == 2 + ): + # for binary classifier only keep positive class: + self._shap_interaction_values = [ + self._shap_interaction_values[:, :, :, 1] + ] + elif ( + isinstance(self._shap_interaction_values, np.ndarray) + and len(self._shap_interaction_values.shape) == 3 + ): + # for binary classifier only keep positive class: self._shap_interaction_values = [self._shap_interaction_values] elif ( isinstance(self._shap_interaction_values, list) @@ -3086,6 +3098,15 @@ def shap_interaction_values(self, pos_label=None): "Adjust the labels parameter accordingly!" ) else: + if ( + isinstance(self._shap_interaction_values, np.ndarray) + and len(self._shap_interaction_values.shape) == 4 + and self._shap_interaciton_values.shape[3] > 2 + ): + self._shap_interaciton_values = [ + self._shap_interaction_values[:, :, :, i] + for i in range(self._shap_interaciton_values.shape) + ] assert len(self._shap_interaction_values) == len(self.labels), ( f"len(self.label)={len(self.labels)}, but " f"shap returned shap values for {len(self._shap_interaction_values)} classes! "