Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interpret support group effects #721

Merged
merged 8 commits into from
Sep 18, 2023
28 changes: 24 additions & 4 deletions bambi/interpret/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def predictions(
use_hdi: bool = True,
prob=None,
transforms=None,
sample_new_groups=False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Predictions

Expand Down Expand Up @@ -443,6 +444,9 @@ def predictions(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.

Returns
-------
Expand Down Expand Up @@ -496,11 +500,15 @@ def predictions(
response_transform = transforms.get(response_name, identity)

if pps:
idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False, kind="pps"
)
y_hat = response_transform(idata.posterior_predictive[response.name])
y_hat_mean = y_hat.mean(("chain", "draw"))
else:
idata = model.predict(idata, data=cap_data, inplace=False)
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
)
y_hat = response_transform(idata.posterior[response.name_target])
y_hat_mean = y_hat.mean(("chain", "draw"))

Expand Down Expand Up @@ -534,6 +542,7 @@ def comparisons(
use_hdi: bool = True,
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Comparisons

Expand Down Expand Up @@ -562,6 +571,9 @@ def comparisons(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.

Returns
-------
Expand Down Expand Up @@ -631,7 +643,9 @@ def comparisons(
comparisons_data = create_differences_data(
conditional_info, contrast_info, conditional_info.user_passed, kind="comparisons"
)
idata = model.predict(idata, data=comparisons_data, inplace=False)
idata = model.predict(
idata, data=comparisons_data, sample_new_groups=sample_new_groups, inplace=False
)

predictive_difference = PredictiveDifferences(
model,
Expand Down Expand Up @@ -663,6 +677,7 @@ def slopes(
use_hdi: bool = True,
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Slopes

Expand Down Expand Up @@ -702,6 +717,9 @@ def slopes(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.

Returns
-------
Expand Down Expand Up @@ -776,7 +794,9 @@ def slopes(
slopes_data = create_differences_data(
conditional_info, wrt_info, conditional_info.user_passed, effect_type
)
idata = model.predict(idata, data=slopes_data, inplace=False)
idata = model.predict(
idata, data=slopes_data, sample_new_groups=sample_new_groups, inplace=False
)

predictive_difference = PredictiveDifferences(
model, slopes_data, wrt_info, conditional_info, response, use_hdi, effect_type
Expand Down
15 changes: 15 additions & 0 deletions bambi/interpret/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def plot_predictions(
idata: az.InferenceData,
covariates: Union[str, list],
target: str = "mean",
sample_new_groups: bool = False,
pps: bool = False,
use_hdi: bool = True,
prob=None,
Expand All @@ -110,6 +111,9 @@ def plot_predictions(
Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only
works when pps is False as the target may not be available in the posterior predictive
distribution.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
pps: bool, optional
Whether to plot the posterior predictive samples. Defaults to ``False``.
use_hdi : bool, optional
Expand Down Expand Up @@ -174,6 +178,7 @@ def plot_predictions(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

response_name = get_aliased_name(model.response_component.response_term)
Expand Down Expand Up @@ -219,6 +224,7 @@ def plot_comparisons(
conditional: Union[str, dict, list, None] = None,
average_by: Union[str, list] = None,
comparison_type: str = "diff",
sample_new_groups: bool = False,
use_hdi: bool = True,
prob=None,
legend: bool = True,
Expand All @@ -245,6 +251,9 @@ def plot_comparisons(
over the other covariates in the model. Defaults to ``None``.
comparison_type : str, optional
The type of comparison to plot. Defaults to 'diff'.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
prob : float, optional
Expand Down Expand Up @@ -332,6 +341,7 @@ def plot_comparisons(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

return _plot_differences(
Expand All @@ -355,6 +365,7 @@ def plot_slopes(
average_by: Union[str, list] = None,
eps: float = 1e-4,
slope: str = "dydx",
sample_new_groups: bool = False,
use_hdi: bool = True,
prob=None,
transforms=None,
Expand Down Expand Up @@ -395,6 +406,9 @@ def plot_slopes(
change in the response.
'dyex' represents a percent change in 'wrt' is associated with a unit increase
in the response.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
prob : float, optional
Expand Down Expand Up @@ -486,6 +500,7 @@ def plot_slopes(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

return _plot_differences(
Expand Down
40 changes: 17 additions & 23 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable = too-many-function-args
# pylint: disable = too-many-nested-blocks
from dataclasses import dataclass, field
import re
from statistics import mode
from typing import Union

Expand Down Expand Up @@ -211,17 +212,21 @@ def get_model_covariates(model: Model) -> np.ndarray:
"""

terms = get_model_terms(model)
names = []
covariates = []
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
# If the component is a function call, use the argument names
if isinstance(component, Call):
names.append([arg.name for arg in component.call.args])
covariates.append([arg.name for arg in component.call.args])
else:
names.append([component.name])
covariates.append([component.name])
elif hasattr(term, "factor"):
covariates.append(list(term.var_names))

return np.unique(names)
flatten_covariates = [item for sublist in covariates for item in sublist]

return np.unique(flatten_covariates)


def get_covariates(covariates: dict) -> Covariates:
Expand Down Expand Up @@ -330,25 +335,14 @@ def set_default_values(model: Model, data_dict: dict, kind: str) -> dict:
"slopes",
), "kind must be either 'comparisons', 'slopes', or 'predictions'"

terms = get_model_terms(model)

# Get default values for each variable in the model
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
# If the component is a function call, use the argument names
if isinstance(component, Call):
names = [arg.name for arg in component.call.args]
else:
names = [component.name]
for name in names:
if name not in data_dict:
# For numeric predictors, select the mean.
if component.kind == "numeric":
data_dict[name] = np.mean(model.data[name])
# For categoric predictors, select the most frequent level.
elif component.kind == "categoric":
data_dict[name] = mode(model.data[name])
unique_covariates = get_model_covariates(model)
for name in unique_covariates:
if name not in data_dict:
dtype = str(model.data[name].dtype)
if re.match(r"float*|int*", dtype):
data_dict[name] = np.mean(model.data[name])
elif dtype in ("category", "dtype"):
data_dict[name] = mode(model.data[name])

if kind in ("comparisons", "slopes"):
# if value in dict is not a list then convert to a list
Expand Down
63 changes: 62 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,23 @@

@pytest.fixture(scope="module")
def mtcars():
"Model with common level effects only"
data = bmb.load_data('mtcars')
data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True)
model = bmb.Model("mpg ~ hp * drat * am", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


@pytest.fixture(scope="module")
def sleep_study():
"Model with common and group specific effects"
data = bmb.load_data('sleepstudy')
model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Improvement:
# * Test the actual plots are what we are indeed the desired result.
# * Test using the dictionary and the list gives the same plot
Expand Down Expand Up @@ -224,6 +234,19 @@ def test_multiple_outputs_with_alias(self, pps):

# Test user supplied target argument
plot_predictions(model, idata, "x", "alpha", pps=False)


def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_predictions(model, idata, ["Days", "Subject"], sample_new_groups=True)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_predictions(model, idata, ["Days", "Subject"])


class TestComparison:
Expand Down Expand Up @@ -294,6 +317,25 @@ def test_average_by(self, mtcars, average_by):

# unit level with average by
plot_comparisons(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_comparisons(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_comparisons(
model,
idata,
contrast={"Days": [2, 4]},
conditional={"Subject": [308, 335, 352, 372]},
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_comparisons(model, idata, "Days", "Subject")


class TestSlopes:
Expand Down Expand Up @@ -374,4 +416,23 @@ def test_average_by(self, mtcars, average_by):
plot_slopes(model, idata, "hp", ["am", "drat"], average_by)

# unit level with average by
plot_slopes(model, idata, "hp", None, average_by)
plot_slopes(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_slopes(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_slopes(
model,
idata,
wrt={"Days": 2},
conditional={"Subject": 308}
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_slopes(model, idata, "Days", "Subject")