Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for plotting order in plot_comparison (fixes mismatch between labels and plot) #731

Merged
merged 20 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def build(self, bmb_model):
contribution = pt.concatenate(contribution_list)[indexes_to_unsort]
# If there are no groups, it's a single dot product
else:
contribution = phi @ coeffs
contribution = pt.dot(phi, coeffs) # "@" operator is not working as expected

output = pm.Deterministic(label, contribution, dims=contribution_dims)
return output
Expand Down
5 changes: 5 additions & 0 deletions bambi/interpret/create_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def _grid_level(
elif kind == "slopes":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid, variable_info.name)

# Computing default values may produce fractional values. Enforcing an
# "int" dtype rounds them, which can leave duplicate rows in the grid,
# so drop the duplicates before returning.
pairwise_grid = pairwise_grid.drop_duplicates()

return pairwise_grid


Expand Down
16 changes: 8 additions & 8 deletions bambi/interpret/plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd

from bambi.interpret.utils import Covariates, get_unique_levels, get_group_offset, identity
from bambi.interpret.utils import Covariates, get_group_offset, identity


def plot_numeric(
Expand Down Expand Up @@ -49,7 +49,7 @@ def plot_numeric(
ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
elif "group" in covariates and not "panel" in covariates:
ax = axes[0]
colors = get_unique_levels(plot_data[color])
colors = np.unique(plot_data[color])
for i, clr in enumerate(colors):
idx = (plot_data[color] == clr).to_numpy()
values_main = transform_main(plot_data.loc[idx, main])
Expand All @@ -62,16 +62,16 @@ def plot_numeric(
color=f"C{i}",
)
elif not "group" in covariates and "panel" in covariates:
panels = get_unique_levels(plot_data[panel])
panels = np.unique(plot_data[panel])
for ax, pnl in zip(axes.ravel(), panels):
idx = (plot_data[panel] == pnl).to_numpy()
values_main = transform_main(plot_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], solid_capstyle="butt")
ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.4)
ax.set(title=f"{panel} = {pnl}")
elif "group" in covariates and "panel" in covariates:
colors = get_unique_levels(plot_data[color])
panels = get_unique_levels(plot_data[panel])
colors = np.unique(plot_data[color])
panels = np.unique(plot_data[panel])
if color == panel:
for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
idx = (plot_data[panel] == pnl).to_numpy()
Expand Down Expand Up @@ -138,20 +138,20 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool

main, color, panel = covariates.main, covariates.group, covariates.panel
covariates = {k: v for k, v in vars(covariates).items() if v is not None}
main_levels = get_unique_levels(plot_data[main])
main_levels = np.unique(plot_data[main])
main_levels_n = len(main_levels)
idxs_main = np.arange(main_levels_n)
y_hat_mean = plot_data["estimate"]
y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)

if "group" in covariates:
colors = get_unique_levels(plot_data[color])
colors = np.unique(plot_data[color])
colors_n = len(colors)
offset_bounds = get_group_offset(colors_n)
colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)

if "panel" in covariates:
panels = get_unique_levels(plot_data[panel])
panels = np.unique(plot_data[panel])

if len(covariates) == 1:
ax = axes[0]
Expand Down
34 changes: 20 additions & 14 deletions bambi/interpret/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,16 @@ def plot_comparisons(
if conditional is None and average_by is None:
raise ValueError("Must specify at least one of 'conditional' or 'average_by'.")

if conditional is not None:
if not isinstance(conditional, str):
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)
if isinstance(conditional, dict):
conditional = {key: sorted(listify(value)) for key, value in conditional.items()}
elif conditional is not None:
conditional = listify(conditional)
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)

if average_by is True:
raise ValueError(
"Plotting when 'average_by = True' is not possible as 'True' marginalizes "
Expand Down Expand Up @@ -469,13 +472,16 @@ def plot_slopes(

if conditional is None and average_by is None:
raise ValueError("Must specify at least one of 'conditional' or 'average_by'.")
if conditional is not None:
if not isinstance(conditional, str):
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)

if isinstance(conditional, dict):
conditional = {key: sorted(listify(value)) for key, value in conditional.items()}
elif conditional is not None:
conditional = listify(conditional)
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)

if average_by is True:
raise ValueError(
Expand Down
13 changes: 1 addition & 12 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def set_default_variable_values(self) -> np.ndarray:
predictor_data, self.eps, dtype
)
elif component.kind == "categoric":
values = get_unique_levels(predictor_data)
values = np.unique(predictor_data)

return values

Expand Down Expand Up @@ -377,17 +377,6 @@ def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray:
raise ValueError("Group covariate must be numeric or categoric.")


def get_unique_levels(x: np.ndarray) -> np.ndarray:
    """
    Return the distinct levels of a categoric variable.

    Parameters
    ----------
    x : np.ndarray
        The values to inspect. Objects whose ``dtype`` exposes a
        ``categories`` attribute (such as a categorical dtype) are also
        accepted.

    Returns
    -------
    np.ndarray
        The categories declared on ``x.dtype``, in their declared order, when
        that attribute exists; otherwise the result of ``np.unique(x)``.
    """
    has_declared_categories = hasattr(x, "dtype") and hasattr(x.dtype, "categories")

    # Plain arrays and sequences carry no declared categories: derive the
    # levels from the values themselves.
    if not has_declared_categories:
        return np.unique(x)

    # The dtype carries its own categories; return them as a new array.
    return np.array(x.dtype.categories)


def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:
# Complementary log log function, scaled.
# See following code to have an idea of how this function looks like
Expand Down
Loading