Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jt-lab/bambi
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Oct 3, 2023
2 parents 3b319ff + cdd40df commit 1387d44
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 20 deletions.
5 changes: 5 additions & 0 deletions bambi/interpret/create_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def _grid_level(
elif kind == "slopes":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid, variable_info.name)

# After computing default values, fractional values may have been computed.
# Enforcing the dtype of "int" may create duplicate rows as it will round
# the fractional values.
pairwise_grid = pairwise_grid.drop_duplicates()

return pairwise_grid


Expand Down
16 changes: 8 additions & 8 deletions bambi/interpret/plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd

from bambi.interpret.utils import Covariates, get_unique_levels, get_group_offset, identity
from bambi.interpret.utils import Covariates, get_group_offset, identity


def plot_numeric(
Expand Down Expand Up @@ -49,7 +49,7 @@ def plot_numeric(
ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
elif "group" in covariates and not "panel" in covariates:
ax = axes[0]
colors = get_unique_levels(plot_data[color])
colors = np.unique(plot_data[color])
for i, clr in enumerate(colors):
idx = (plot_data[color] == clr).to_numpy()
values_main = transform_main(plot_data.loc[idx, main])
Expand All @@ -62,16 +62,16 @@ def plot_numeric(
color=f"C{i}",
)
elif not "group" in covariates and "panel" in covariates:
panels = get_unique_levels(plot_data[panel])
panels = np.unique(plot_data[panel])
for ax, pnl in zip(axes.ravel(), panels):
idx = (plot_data[panel] == pnl).to_numpy()
values_main = transform_main(plot_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], solid_capstyle="butt")
ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.4)
ax.set(title=f"{panel} = {pnl}")
elif "group" in covariates and "panel" in covariates:
colors = get_unique_levels(plot_data[color])
panels = get_unique_levels(plot_data[panel])
colors = np.unique(plot_data[color])
panels = np.unique(plot_data[panel])
if color == panel:
for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
idx = (plot_data[panel] == pnl).to_numpy()
Expand Down Expand Up @@ -138,20 +138,20 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool

main, color, panel = covariates.main, covariates.group, covariates.panel
covariates = {k: v for k, v in vars(covariates).items() if v is not None}
main_levels = get_unique_levels(plot_data[main])
main_levels = np.unique(plot_data[main])
main_levels_n = len(main_levels)
idxs_main = np.arange(main_levels_n)
y_hat_mean = plot_data["estimate"]
y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)

if "group" in covariates:
colors = get_unique_levels(plot_data[color])
colors = np.unique(plot_data[color])
colors_n = len(colors)
offset_bounds = get_group_offset(colors_n)
colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)

if "panel" in covariates:
panels = get_unique_levels(plot_data[panel])
panels = np.unique(plot_data[panel])

if len(covariates) == 1:
ax = axes[0]
Expand Down
6 changes: 6 additions & 0 deletions bambi/interpret/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ def plot_comparisons(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)
if isinstance(conditional, dict):
conditional = {key: sorted(value) for key, value in conditional.items()}

if average_by is True:
raise ValueError(
"Plotting when 'average_by = True' is not possible as 'True' marginalizes "
Expand Down Expand Up @@ -469,13 +472,16 @@ def plot_slopes(

if conditional is None and average_by is None:
raise ValueError("Must specify at least one of 'conditional' or 'average_by'.")

if conditional is not None:
if not isinstance(conditional, str):
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)
if isinstance(conditional, dict):
conditional = {key: sorted(value) for key, value in conditional.items()}

if average_by is True:
raise ValueError(
Expand Down
13 changes: 1 addition & 12 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def set_default_variable_values(self) -> np.ndarray:
predictor_data, self.eps, dtype
)
elif component.kind == "categoric":
values = get_unique_levels(predictor_data)
values = np.unique(predictor_data)

return values

Expand Down Expand Up @@ -377,17 +377,6 @@ def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray:
raise ValueError("Group covariate must be numeric or categoric.")


def get_unique_levels(x: np.ndarray) -> np.ndarray:
"""
Get unique levels of a categoric variable.
"""
if hasattr(x, "dtype") and hasattr(x.dtype, "categories"):
levels = np.array((x.dtype.categories))
else:
levels = np.unique(x)
return levels


def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:
# Complementary log log function, scaled.
# See following code to have an idea of how this function looks like
Expand Down

0 comments on commit 1387d44

Please sign in to comment.