interpret support group effects (#721)
* [715] Adding Mr P docs

Signed-off-by: Nathaniel <[email protected]>

* [715] updating black formatting and suppressing numpyro messages

Signed-off-by: Nathaniel <[email protected]>

* [715] adding plot_comparisons checks on basic model

Signed-off-by: Nathaniel <[email protected]>

* remove Mr.P example and metadata

* support group-specific-effects default values and sampling from new groups

* use set with in for conditional statement

* remove commits related to PR #715

* add group specific effects for interpret plotting functions

---------

Signed-off-by: Nathaniel <[email protected]>
Co-authored-by: Nathaniel <[email protected]>
GStechschulte and NathanielF authored Sep 18, 2023
1 parent 16b005b commit ab54b99
Showing 4 changed files with 118 additions and 28 deletions.
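For context, a minimal sketch of the new opt-in flag, mirroring the sleepstudy fixture added in tests/test_plots.py below; the import path assumes plot_predictions is exposed from the public bambi.interpret namespace:

import bambi as bmb
from bambi.interpret import plot_predictions

data = bmb.load_data("sleepstudy")
model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)

# With group-specific effects, a grid that introduces unseen group levels
# must now opt in via sample_new_groups=True.
plot_predictions(model, idata, ["Days", "Subject"], sample_new_groups=True)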
28 changes: 24 additions & 4 deletions bambi/interpret/effects.py
@@ -411,6 +411,7 @@ def predictions(
use_hdi: bool = True,
prob=None,
transforms=None,
sample_new_groups=False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Predictions
@@ -443,6 +444,9 @@ def predictions(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
Returns
-------
@@ -496,11 +500,15 @@ def predictions(
response_transform = transforms.get(response_name, identity)

if pps:
idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False, kind="pps"
)
y_hat = response_transform(idata.posterior_predictive[response.name])
y_hat_mean = y_hat.mean(("chain", "draw"))
else:
idata = model.predict(idata, data=cap_data, inplace=False)
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
)
y_hat = response_transform(idata.posterior[response.name_target])
y_hat_mean = y_hat.mean(("chain", "draw"))

@@ -534,6 +542,7 @@ def comparisons(
use_hdi: bool = True,
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Comparisons
@@ -562,6 +571,9 @@ def comparisons(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
Returns
-------
@@ -631,7 +643,9 @@ def comparisons(
comparisons_data = create_differences_data(
conditional_info, contrast_info, conditional_info.user_passed, kind="comparisons"
)
idata = model.predict(idata, data=comparisons_data, inplace=False)
idata = model.predict(
idata, data=comparisons_data, sample_new_groups=sample_new_groups, inplace=False
)

predictive_difference = PredictiveDifferences(
model,
@@ -663,6 +677,7 @@ def slopes(
use_hdi: bool = True,
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Slopes
@@ -702,6 +717,9 @@ def slopes(
transforms : dict, optional
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
Returns
-------
@@ -776,7 +794,9 @@ def slopes(
slopes_data = create_differences_data(
conditional_info, wrt_info, conditional_info.user_passed, effect_type
)
idata = model.predict(idata, data=slopes_data, inplace=False)
idata = model.predict(
idata, data=slopes_data, sample_new_groups=sample_new_groups, inplace=False
)

predictive_difference = PredictiveDifferences(
model, slopes_data, wrt_info, conditional_info, response, use_hdi, effect_type
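The DataFrame-returning functions in effects.py accept the same flag; a hedged sketch using comparisons(), reusing model and idata from the sleepstudy sketch above (argument names taken from the signatures in this diff; the Subject levels are illustrative):

from bambi.interpret import comparisons

# sample_new_groups is forwarded straight to model.predict(); with the
# default False, rows for unseen Subject levels raise a ValueError instead.
summary = comparisons(
    model,
    idata,
    contrast={"Days": [2, 4]},
    conditional={"Subject": [308, 335]},
    sample_new_groups=True,
)
print(summary.head())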
15 changes: 15 additions & 0 deletions bambi/interpret/plotting.py
@@ -86,6 +86,7 @@ def plot_predictions(
idata: az.InferenceData,
covariates: Union[str, list],
target: str = "mean",
sample_new_groups: bool = False,
pps: bool = False,
use_hdi: bool = True,
prob=None,
@@ -110,6 +111,9 @@
Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only
works when pps is False as the target may not be available in the posterior predictive
distribution.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
pps: bool, optional
Whether to plot the posterior predictive samples. Defaults to ``False``.
use_hdi : bool, optional
@@ -174,6 +178,7 @@ def plot_predictions(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

response_name = get_aliased_name(model.response_component.response_term)
@@ -219,6 +224,7 @@ def plot_comparisons(
conditional: Union[str, dict, list, None] = None,
average_by: Union[str, list] = None,
comparison_type: str = "diff",
sample_new_groups: bool = False,
use_hdi: bool = True,
prob=None,
legend: bool = True,
@@ -245,6 +251,9 @@
over the other covariates in the model. Defaults to ``None``.
comparison_type : str, optional
The type of comparison to plot. Defaults to 'diff'.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
prob : float, optional
@@ -332,6 +341,7 @@ def plot_comparisons(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

return _plot_differences(
@@ -355,6 +365,7 @@ def plot_slopes(
average_by: Union[str, list] = None,
eps: float = 1e-4,
slope: str = "dydx",
sample_new_groups: bool = False,
use_hdi: bool = True,
prob=None,
transforms=None,
@@ -395,6 +406,9 @@ def plot_slopes(
change in the response.
'dyex' represents a percent change in 'wrt' is associated with a unit increase
in the response.
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
prob : float, optional
@@ -486,6 +500,7 @@ def plot_slopes(
use_hdi=use_hdi,
prob=prob,
transforms=transforms,
sample_new_groups=sample_new_groups,
)

return _plot_differences(
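All three plotting functions keep sample_new_groups=False as the default, so existing code fails loudly on unseen groups. A sketch of the expected behavior, again reusing the sleepstudy model and idata (the error text follows the tests below):

from bambi.interpret import plot_slopes

try:
    # Default sample_new_groups=False: unseen groups are rejected.
    plot_slopes(model, idata, "Days", "Subject")
except ValueError as err:
    print(err)  # There are new groups for the factors ('Subject',) ...

# Opting in samples group-specific effects for the new levels.
plot_slopes(model, idata, "Days", "Subject", sample_new_groups=True)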
40 changes: 17 additions & 23 deletions bambi/interpret/utils.py
@@ -1,6 +1,7 @@
# pylint: disable = too-many-function-args
# pylint: disable = too-many-nested-blocks
from dataclasses import dataclass, field
import re
from statistics import mode
from typing import Union

@@ -211,17 +212,21 @@ def get_model_covariates(model: Model) -> np.ndarray:
"""

terms = get_model_terms(model)
names = []
covariates = []
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
# If the component is a function call, use the argument names
if isinstance(component, Call):
names.append([arg.name for arg in component.call.args])
covariates.append([arg.name for arg in component.call.args])
else:
names.append([component.name])
covariates.append([component.name])
elif hasattr(term, "factor"):
covariates.append(list(term.var_names))

return np.unique(names)
flatten_covariates = [item for sublist in covariates for item in sublist]

return np.unique(flatten_covariates)


def get_covariates(covariates: dict) -> Covariates:
@@ -330,25 +335,14 @@ def set_default_values(model: Model, data_dict: dict, kind: str) -> dict:
"slopes",
), "kind must be either 'comparisons', 'slopes', or 'predictions'"

terms = get_model_terms(model)

# Get default values for each variable in the model
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
# If the component is a function call, use the argument names
if isinstance(component, Call):
names = [arg.name for arg in component.call.args]
else:
names = [component.name]
for name in names:
if name not in data_dict:
# For numeric predictors, select the mean.
if component.kind == "numeric":
data_dict[name] = np.mean(model.data[name])
# For categoric predictors, select the most frequent level.
elif component.kind == "categoric":
data_dict[name] = mode(model.data[name])
unique_covariates = get_model_covariates(model)
for name in unique_covariates:
if name not in data_dict:
dtype = str(model.data[name].dtype)
if re.match(r"float*|int*", dtype):
data_dict[name] = np.mean(model.data[name])
elif dtype in ("category", "dtype"):
data_dict[name] = mode(model.data[name])

if kind in ("comparisons", "slopes"):
# if value in dict is not a list then convert to a list
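The refactored set_default_values now keys off the pandas dtype string instead of walking term components. A standalone sketch of that dispatch, using illustrative series rather than the Bambi API:

import re
from statistics import mode

import numpy as np
import pandas as pd

values = {
    "hp": pd.Series([110.0, 110.0, 93.0]),
    "am": pd.Series([0, 1, 0], dtype="category"),
}
defaults = {}
for name, series in values.items():
    dtype = str(series.dtype)
    if re.match(r"float*|int*", dtype):  # numeric predictors default to the mean
        defaults[name] = np.mean(series)
    elif dtype == "category":            # categoric predictors default to the mode
        defaults[name] = mode(series)
print(defaults)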
63 changes: 62 additions & 1 deletion tests/test_plots.py
@@ -11,13 +11,23 @@

@pytest.fixture(scope="module")
def mtcars():
"Model with common level effects only"
data = bmb.load_data('mtcars')
data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True)
model = bmb.Model("mpg ~ hp * drat * am", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


@pytest.fixture(scope="module")
def sleep_study():
"Model with common and group specific effects"
data = bmb.load_data('sleepstudy')
model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Improvement:
# * Test that the actual plots are indeed the desired result.
# * Test that using the dictionary and the list gives the same plot
@@ -224,6 +234,19 @@ def test_multiple_outputs_with_alias(self, pps):

# Test user supplied target argument
plot_predictions(model, idata, "x", "alpha", pps=False)


def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_predictions(model, idata, ["Days", "Subject"], sample_new_groups=True)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_predictions(model, idata, ["Days", "Subject"])


class TestComparison:
@@ -294,6 +317,25 @@ def test_average_by(self, mtcars, average_by):

# unit level with average by
plot_comparisons(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_comparisons(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_comparisons(
model,
idata,
contrast={"Days": [2, 4]},
conditional={"Subject": [308, 335, 352, 372]},
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_comparisons(model, idata, "Days", "Subject")


class TestSlopes:
@@ -374,4 +416,23 @@ def test_average_by(self, mtcars, average_by):
plot_slopes(model, idata, "hp", ["am", "drat"], average_by)

# unit level with average by
plot_slopes(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_slopes(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_slopes(
model,
idata,
wrt={"Days": 2},
conditional={"Subject": 308}
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_slopes(model, idata, "Days", "Subject")
