Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
snehankekre committed Feb 1, 2022
1 parent 6ab2975 commit 1c05836
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# streamlit-shap

This component provides a wrapper to display [SHAP](https://github.com/slundberg/shap) plots in [Streamlit](https://streamlit.io/).

### Installation

`pip install git+https://github.com/snehankekre/streamlit-shap.git`
Expand Down Expand Up @@ -57,7 +59,9 @@ st_shap(shap.plots.beeswarm(shap_values), height=300)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

st_shap(shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:]), height=200)
st_shap(shap.force_plot(explainer.expected_value, shap_values[:1000,:], X_display.iloc[:1000,:]), height=500)
st_shap(shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:]), height=200, width=1000)
st_shap(shap.force_plot(explainer.expected_value, shap_values[:1000,:], X_display.iloc[:1000,:]), height=400, width=1000)

```

![st_shap](example.gif)
Binary file added example.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions streamlit_shap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from io import BytesIO


def st_shap(plot, height=None):
def st_shap(plot, height=None, width=None):
"""Takes a SHAP plot as input, and returns a streamlit.delta_generator.DeltaGenerator as output.
It is recommended to set the height and omit the width
It is recommended to set the height and width
parameter to have the plot fit to the window.
Parameters
Expand Down Expand Up @@ -82,7 +82,7 @@ def st_shap(plot, height=None):
elif hasattr(plot, "html") or hasattr(plot, "data") or hasattr(plot, "matplotlib"):

shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
fig = components.html(shap_html, height=height)
fig = components.html(shap_html, height=height, width=width)

else:
fig = components.html(
Expand Down

0 comments on commit 1c05836

Please sign in to comment.