Skip to content

Commit

Permalink
add group specific effects for interpret plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Sep 15, 2023
1 parent 3cf30f4 commit abed003
Showing 1 changed file with 62 additions and 1 deletion.
63 changes: 62 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,23 @@

@pytest.fixture(scope="module")
def mtcars():
"Model with common level effects only"
data = bmb.load_data('mtcars')
data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True)
model = bmb.Model("mpg ~ hp * drat * am", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


@pytest.fixture(scope="module")
def sleep_study():
"Model with common and group specific effects"
data = bmb.load_data('sleepstudy')
model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Improvement:
# * Test the actual plots are what we are indeed the desired result.
# * Test using the dictionary and the list gives the same plot
Expand Down Expand Up @@ -224,6 +234,19 @@ def test_multiple_outputs_with_alias(self, pps):

# Test user supplied target argument
plot_predictions(model, idata, "x", "alpha", pps=False)


def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_predictions(model, idata, ["Days", "Subject"], sample_new_groups=True)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_predictions(model, idata, ["Days", "Subject"])


class TestComparison:
Expand Down Expand Up @@ -294,6 +317,25 @@ def test_average_by(self, mtcars, average_by):

# unit level with average by
plot_comparisons(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_comparisons(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_comparisons(
model,
idata,
contrast={"Days": [2, 4]},
conditional={"Subject": [308, 335, 352, 372]},
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_comparisons(model, idata, "Days", "Subject")


class TestSlopes:
Expand Down Expand Up @@ -374,4 +416,23 @@ def test_average_by(self, mtcars, average_by):
plot_slopes(model, idata, "hp", ["am", "drat"], average_by)

# unit level with average by
plot_slopes(model, idata, "hp", None, average_by)
plot_slopes(model, idata, "hp", None, average_by)

def test_group_effects(self, sleep_study):
model, idata = sleep_study

# contains new unseen data
plot_slopes(model, idata, "Days", "Subject", sample_new_groups=True)
# user passed values seen in observed data
plot_slopes(
model,
idata,
wrt={"Days": 2},
conditional={"Subject": 308}
)

with pytest.raises(
ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False."
):
# default: sample_new_groups=False
plot_slopes(model, idata, "Days", "Subject")

0 comments on commit abed003

Please sign in to comment.