Skip to content

Commit

Permalink
adds set_shap_values methods
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Oct 24, 2021
1 parent ff25893 commit 213238f
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
4 changes: 3 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
### New Features
- Export your ExplainerHub to static html with `hub.to_html()` and `hub.save_html()` methods
- Export your ExplainerHub to a zip file of static html exports with the `hub.to_zip()` method
- Manually add pre-calculated shap values with `explainer.set_shap_values()`
- Manually add pre-calculated shap interaction values with `explainer.set_shap_interaction_values()`

### Bug Fixes
-
- Fixed bug with What if tab components static html export (missing `</div>`)
-

### Improvements
Expand Down
118 changes: 118 additions & 0 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,24 @@ def get_shap_values_df(self, pos_label=None):
self._shap_values_df, self.onehot_dict, self.merged_cols).astype(self.precision)
return self._shap_values_df

def set_shap_values(self, base_value:float, shap_values:np.ndarray):
    """Set shap values manually.

    This is useful if you already have shap values calculated, and do not
    want to calculate them again inside the explainer instance. Especially
    for large models and large datasets you may want to calculate shap
    values on specialized hardware, and then add them to the explainer
    manually.

    Args:
        base_value (float): the shap intercept generated by e.g.
            base_value = shap.TreeExplainer(model).expected_value
        shap_values (np.ndarray): Generated by e.g.
            shap_values = shap.TreeExplainer(model).shap_values(X_test)
            Should have one row per row of self.X and one column per
            entry of self.columns.

    Raises:
        ValueError: if shap_values does not have shape
            (len(self.X), len(self.columns)).
    """
    # Validate up front: a row-count mismatch would otherwise be accepted
    # silently, because pd.DataFrame only checks the number of columns.
    expected_shape = (len(self.X), len(self.columns))
    if np.shape(shap_values) != expected_shape:
        raise ValueError(
            "shap_values should be of shape "
            f"({expected_shape[0]}, {expected_shape[1]})!")

    shap_df = pd.DataFrame(shap_values, columns=self.columns)
    shap_df = merge_categorical_shap_values(
        shap_df, self.onehot_dict, self.merged_cols).astype(self.precision)

    # Only touch instance state once everything above succeeded, so that a
    # failed call never leaves a new base value next to old shap values.
    self._shap_base_value = base_value
    self._shap_values_df = shap_df

@insert_pos_label
def get_shap_row(self, index=None, X_row=None, pos_label=None):
if index is not None:
Expand Down Expand Up @@ -1018,6 +1036,27 @@ def shap_interaction_values(self, pos_label=None):
self.onehot_dict).astype(self.precision)
return self._shap_interaction_values

def set_shap_interaction_values(self, shap_interaction_values:np.ndarray):
    """Manually set shap interaction values.

    Use this in case you have already pre-computed the interaction values
    elsewhere and do not want to re-calculate them again inside the
    explainer instance.

    Args:
        shap_interaction_values (np.ndarray): shap interactions values of
            shape (n, m, m), where n is len(self.X) and m is
            len(self.original_cols).

    Raises:
        ValueError: if shap_interaction_values is not a numpy array, or
            does not have the expected (n, m, m) shape.
    """
    if not isinstance(shap_interaction_values, np.ndarray):
        raise ValueError("shap_interaction_values should be a numpy array")

    n_rows = len(self.X)
    n_cols = len(self.original_cols)
    if shap_interaction_values.shape != (n_rows, n_cols, n_cols):
        raise ValueError(
            "shap_interaction_values should be of shape "
            f"({n_rows}, {n_cols}, {n_cols})!")

    self._shap_interaction_values = merge_categorical_shap_interaction_values(
        shap_interaction_values, self.columns, self.merged_cols,
        self.onehot_dict).astype(self.precision)


@insert_pos_label
def mean_abs_shap_df(self, pos_label=None):
"""Mean absolute SHAP values per feature."""
Expand Down Expand Up @@ -2328,6 +2367,58 @@ def get_shap_values_df(self, pos_label=None):
else:
raise ValueError(f"pos_label={pos_label}, but should be either 1 or 0!")

def set_shap_values(self, base_value:List[float], shap_values:List):
    """Set shap values manually.

    This is useful if you already have shap values calculated, and do not
    want to calculate them again inside the explainer instance. Especially
    for large models and large datasets you may want to calculate shap
    values on specialized hardware, and then add them to the explainer
    manually.

    Args:
        base_value (list[float]): list of shap intercepts generated by e.g.
            base_value = shap.TreeExplainer(model).expected_value.
            Should be a list with a float for each class (a 1-d np.ndarray
            with one entry per class is converted to a list). For binary
            classification and some models shap only provides the base
            value for the positive class, in which case you need to
            provide [1-base_value, base_value] or [-base_value, base_value]
            depending on whether the shap values are for probabilities or
            logodds.
        shap_values (list[np.ndarray]): Generated by e.g.
            shap_values = shap.TreeExplainer(model).shap_values(X_test)
            One 2-d array per class, each with len(self.X) rows and
            len(self.original_cols) columns. For binary classification and
            some models shap only provides the shap values for the
            positive class, in which case you need to provide
            [1-shap_values, shap_values] or [-shap_values, shap_values]
            depending on whether the shap values are for probabilities or
            logodds.

    Raises:
        ValueError: if base_value or shap_values is not a list of the right
            length, or if an element of shap_values is not a 2-d np.ndarray
            with the expected number of rows and columns.
    """
    n_labels = len(self.labels)

    if isinstance(base_value, np.ndarray) and base_value.shape == (n_labels,):
        base_value = list(base_value)
    if not isinstance(base_value, list):
        raise ValueError("base_value should be a list of floats with an expected value for each class")
    if len(base_value) != n_labels:
        raise ValueError("base value should be a list with an expected value "
                         f"for each class, so should be length {n_labels}")

    if not isinstance(shap_values, list):
        raise ValueError("shap_values should be a list of np.ndarray with shap values for each class")
    if len(shap_values) != n_labels:
        raise ValueError("shap_values should be a list with a np.ndarray of shap values "
                         f"for each class, so should be length {n_labels}")

    # Validate every array before merging anything, so that an invalid
    # element never leaves the explainer with partially replaced values.
    n_rows = len(self.X)
    n_cols = len(self.original_cols)
    for sv in shap_values:
        if not isinstance(sv, np.ndarray):
            raise ValueError("each element of shap_values should be an np.ndarray!")
        if sv.ndim != 2:
            raise ValueError("each element of shap_values should be a 2-dimensional np.ndarray!")
        if sv.shape[0] != n_rows:
            raise ValueError(f"Expected shap values to have {n_rows} rows!")
        if sv.shape[1] != n_cols:
            raise ValueError(f"Expected shap values to have {n_cols} columns!")

    merged_per_class = [
        merge_categorical_shap_values(
            pd.DataFrame(sv, columns=self.columns),
            self.onehot_dict, self.merged_cols).astype(self.precision)
        for sv in shap_values
    ]

    self._shap_base_value = base_value
    # With exactly two labels only the frame for the second class
    # (index 1) is stored; otherwise the full per-class list is kept.
    if n_labels == 2:
        self._shap_values_df = merged_per_class[1]
    else:
        self._shap_values_df = merged_per_class


@insert_pos_label
def get_shap_row(self, index=None, X_row=None, pos_label=None):
def X_row_to_shap_row(X_row):
Expand Down Expand Up @@ -2418,6 +2509,33 @@ def shap_interaction_values(self, pos_label=None):
else:
raise ValueError(f"pos_label={pos_label}, but should be either 1 or 0!")

def set_shap_interaction_values(self, shap_interaction_values:List[np.ndarray]):
    """Manually set shap interaction values.

    Use this in case you have already pre-computed the interaction values
    elsewhere and do not want to re-calculate them again inside the
    explainer instance.

    Args:
        shap_interaction_values (list[np.ndarray]): list with one array of
            shap interaction values per class, each of shape (n, m, m),
            where n is len(self.X) and m is len(self.original_cols).

    Raises:
        ValueError: if shap_interaction_values is not a list with one
            element per class, or if an element is not an np.ndarray of
            the expected (n, m, m) shape.
    """
    n_labels = len(self.labels)
    if not isinstance(shap_interaction_values, list):
        raise ValueError("shap_interaction_values should be a list of np.ndarray with shap interaction values for each class")
    if len(shap_interaction_values) != n_labels:
        raise ValueError("shap_interaction_values should be a list with a np.ndarray of shap interaction values "
                         f"for each class, so should be length {n_labels}")

    # Validate every array before merging anything, so that an invalid
    # element never leaves the explainer with an empty or partial list.
    n_rows = len(self.X)
    n_cols = len(self.original_cols)
    for siv in shap_interaction_values:
        if not isinstance(siv, np.ndarray):
            raise ValueError("each element of shap_interaction_values should be an np.ndarray!")
        if siv.shape != (n_rows, n_cols, n_cols):
            raise ValueError(f"Expected shap interaction values to have shape of "
                             f"({n_rows}, {n_cols}, {n_cols})")

    merged_per_class = [
        merge_categorical_shap_interaction_values(
            siv, self.columns, self.merged_cols, self.onehot_dict).astype(self.precision)
        for siv in shap_interaction_values
    ]

    # With exactly two labels only the array for the second class
    # (index 1) is stored; otherwise the full per-class list is kept.
    if n_labels == 2:
        self._shap_interaction_values = merged_per_class[1]
    else:
        self._shap_interaction_values = merged_per_class

@insert_pos_label
def mean_abs_shap_df(self, pos_label=None):
"""mean absolute SHAP values"""
Expand Down

0 comments on commit 213238f

Please sign in to comment.