Skip to content

Commit

Permalink
update interpret-community to shap 0.43.0 (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Oct 26, 2023
1 parent 6803ca8 commit 31a3eb2
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
pip install raiwidgets
pip install -r requirements-vis.txt
pip install --upgrade scikit-learn
pip install --upgrade "shap<=0.42.1"
pip install --upgrade "shap<=0.43.0"
- name: Install test dependencies
shell: bash -l {0}
run: |
Expand Down
2 changes: 1 addition & 1 deletion devops/templates/test-run-step-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ steps:
pip install responsibleai
pip install rai-core-flask==0.5.0
pip install raiwidgets --no-deps
pip install --upgrade "shap<=0.42.1"
pip install --upgrade "shap<=0.43.0"
pip install -r requirements-vis.txt
displayName: Install vis required pip packages

Expand Down
1 change: 1 addition & 0 deletions python/interpret_community/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ExplainParams(object):
"""Provide constants for interpret community (init, explain_local and explain_global) parameters."""

BATCH_SIZE = 'batch_size'
CHECK_ADDITIVITY = 'check_additivity'
CLASSES = 'classes'
CLASSIFICATION = 'classification'
EVAL_DATA = 'eval_data'
Expand Down
19 changes: 15 additions & 4 deletions python/interpret_community/shap/deep_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,20 @@ class DeepExplainer(StructuredInitModelExplainer):
:type allow_all_transformations: bool
:param model_task: Optional parameter to specify whether the model is a classification or regression model.
:type model_task: str
:param is_classifier: Optional parameter to specify whether the model is a classification or regression model.
In most cases, the type of the model can be inferred based on the shape of the output, where a classifier
has a predict_proba method and outputs a 2 dimensional array, while a regressor has a predict method and
outputs a 1 dimensional array.
:type is_classifier: bool
:param check_additivity: Optional parameter to specify whether to check the additivity of the SHAP values.
:type check_additivity: bool
"""

@init_tabular_decorator
@init_aggregator_decorator
def __init__(self, model, initialization_examples, explain_subset=None, nclusters=10,
features=None, classes=None, transformations=None, allow_all_transformations=False,
model_task=ModelTask.Unknown, is_classifier=None, **kwargs):
model_task=ModelTask.Unknown, is_classifier=None, check_additivity=True, **kwargs):
"""Initialize the DeepExplainer.
:param model: The DNN model to explain.
Expand Down Expand Up @@ -253,13 +260,15 @@ def __init__(self, model, initialization_examples, explain_subset=None, ncluster
:type transformations: sklearn.compose.ColumnTransformer or list[tuple]
:param allow_all_transformations: Allow many to many and many to one transformations
:type allow_all_transformations: bool
:param model_task: Optional parameter to specify whether the model is a classification or regression model.
:type model_task: str
:param is_classifier: Optional parameter to specify whether the model is a classification or regression model.
In most cases, the type of the model can be inferred based on the shape of the output, where a classifier
has a predict_proba method and outputs a 2 dimensional array, while a regressor has a predict method and
outputs a 1 dimensional array.
:type is_classifier: bool
:param model_task: Optional parameter to specify whether the model is a classification or regression model.
:type model_task: str
:param check_additivity: Optional parameter to specify whether to check the additivity of the SHAP values.
:type check_additivity: bool
"""
self._datamapper = None
if transformations is not None:
Expand All @@ -277,6 +286,7 @@ def __init__(self, model, initialization_examples, explain_subset=None, ncluster
self.transformations = transformations
self.model_task = model_task
self.framework = _get_dnn_model_framework(self.model)
self._check_additivity = check_additivity
summary = _get_summary_data(self.initialization_examples, nclusters, self.framework)
# Suppress warning message from Keras
with logger_redirector(self._logger):
Expand Down Expand Up @@ -340,7 +350,8 @@ def _get_explain_local_kwargs(self, evaluation_examples):
dense_examples = _get_dense_examples(evaluation_examples)
if self.framework == DNNFramework.PYTORCH:
dense_examples = torch.Tensor(dense_examples)
shap_values = self.explainer.shap_values(dense_examples)
shap_values = self.explainer.shap_values(
dense_examples, check_additivity=self._check_additivity)
# use model task to update structure of shap values
single_output = isinstance(shap_values, list) and len(shap_values) == 1
if single_output:
Expand Down
4 changes: 4 additions & 0 deletions python/interpret_community/tabular_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def __init__(self, model, initialization_examples, explain_subset=None, features
kwargs[ExplainParams.MODEL_TASK] = model_task
else:
kwargs.pop(ExplainParams.MODEL_TASK, None)
if uninitialized_explainer == DeepExplainer:
kwargs[ExplainParams.CHECK_ADDITIVITY] = False
else:
kwargs.pop(ExplainParams.CHECK_ADDITIVITY, None)
self.explainer = uninitialized_explainer(
self.model, initialization_examples, transformations=transformations,
allow_all_transformations=allow_all_transformations,
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'scikit-learn',
'packaging',
'interpret-core[required]>=0.1.20, <=0.4.4',
'shap>=0.20.0, <=0.42.1',
'shap>=0.20.0, <=0.43.0',
'raiutils~=0.4.0'
]

Expand Down

0 comments on commit 31a3eb2

Please sign in to comment.