Skip to content

Commit

Permalink
raising ValueError when y_missing
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Nov 18, 2020
1 parent 921c494 commit dac44ce
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,8 +1717,11 @@ def get_prop_for_label(self, prop:str, label):
def y_binary(self):
"""for multiclass problems returns one-vs-rest array of [1,0] pos_label"""
if not hasattr(self, '_y_binaries'):
self._y_binaries = [np.where(self.y.values==i, 1, 0)
for i in range(self.y.nunique())]
if not self.y_missing:
self._y_binaries = [np.where(self.y.values==i, 1, 0)
for i in range(self.y.nunique())]
else:
self._y_binaries = [self.y.values for i in range(len(self.labels))]
return default_list(self._y_binaries, self.pos_label)

@property
Expand Down Expand Up @@ -1945,6 +1948,9 @@ def metrics(self, cutoff=0.5, pos_label=None):
dict
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot calculate metrics!")

if pos_label is None: pos_label = self.pos_label
metrics_dict = {
'accuracy' : accuracy_score(self.y_binary(pos_label), np.where(self.pred_probas(pos_label) > cutoff, 1, 0)),
Expand Down Expand Up @@ -2090,6 +2096,8 @@ def precision_df(self, bin_size=None, quantiles=None, multiclass=False,
pd.DataFrame: precision_df
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot calculate precision_df!")
assert self.pred_probas is not None

if pos_label is None: pos_label = self.pos_label
Expand Down Expand Up @@ -2223,6 +2231,8 @@ def plot_confusion_matrix(self, cutoff=0.5, normalized=False, binary=False, pos_
plotly fig
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot confusion matrix!")
if pos_label is None: pos_label = self.pos_label
pos_label_str = self.labels[pos_label]

Expand Down Expand Up @@ -2288,6 +2298,8 @@ def plot_roc_auc(self, cutoff=0.5, pos_label=None):
Returns:
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot roc auc!")
return plotly_roc_auc_curve(self.y_binary(pos_label), self.pred_probas(pos_label), cutoff=cutoff)

def plot_pr_auc(self, cutoff=0.5, pos_label=None):
Expand All @@ -2302,6 +2314,8 @@ def plot_pr_auc(self, cutoff=0.5, pos_label=None):
Returns:
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot plot PR AUC!")
return plotly_pr_auc_curve(self.y_binary(pos_label), self.pred_probas(pos_label), cutoff=cutoff)

def calculate_properties(self, include_interactions=True):
Expand Down

0 comments on commit dac44ce

Please sign in to comment.