From abed0030445c77eb664d110d6dabaa66e94c3912 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 15 Sep 2023 14:02:09 +0200 Subject: [PATCH] add group specific effects for interpret plotting functions --- tests/test_plots.py | 63 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/tests/test_plots.py b/tests/test_plots.py index 8191e50eb..8ba082df4 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -11,6 +11,7 @@ @pytest.fixture(scope="module") def mtcars(): + "Model with common level effects only" data = bmb.load_data('mtcars') data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True) model = bmb.Model("mpg ~ hp * drat * am", data) @@ -18,6 +19,15 @@ def mtcars(): return model, idata +@pytest.fixture(scope="module") +def sleep_study(): + "Model with common and group specific effects" + data = bmb.load_data('sleepstudy') + model = bmb.Model("Reaction ~ 1 + Days + (Days | Subject)", data) + idata = model.fit(tune=500, draws=500, random_seed=1234) + return model, idata + + # Improvement: # * Test the actual plots are what we are indeed the desired result. # * Test using the dictionary and the list gives the same plot @@ -224,6 +234,19 @@ def test_multiple_outputs_with_alias(self, pps): # Test user supplied target argument plot_predictions(model, idata, "x", "alpha", pps=False) + + + def test_group_effects(self, sleep_study): + model, idata = sleep_study + + # contains new unseen data + plot_predictions(model, idata, ["Days", "Subject"], sample_new_groups=True) + + with pytest.raises( + ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False." + ): + # default: sample_new_groups=False + plot_predictions(model, idata, ["Days", "Subject"]) class TestComparison: @@ -294,6 +317,25 @@ def test_average_by(self, mtcars, average_by): # unit level with average by plot_comparisons(model, idata, "hp", None, average_by) + + def test_group_effects(self, sleep_study): + model, idata = sleep_study + + # contains new unseen data + plot_comparisons(model, idata, "Days", "Subject", sample_new_groups=True) + # user passed values seen in observed data + plot_comparisons( + model, + idata, + contrast={"Days": [2, 4]}, + conditional={"Subject": [308, 335, 352, 372]}, + ) + + with pytest.raises( + ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False." + ): + # default: sample_new_groups=False + plot_comparisons(model, idata, "Days", "Subject") class TestSlopes: @@ -374,4 +416,23 @@ def test_average_by(self, mtcars, average_by): plot_slopes(model, idata, "hp", ["am", "drat"], average_by) # unit level with average by - plot_slopes(model, idata, "hp", None, average_by) \ No newline at end of file + plot_slopes(model, idata, "hp", None, average_by) + + def test_group_effects(self, sleep_study): + model, idata = sleep_study + + # contains new unseen data + plot_slopes(model, idata, "Days", "Subject", sample_new_groups=True) + # user passed values seen in observed data + plot_slopes( + model, + idata, + wrt={"Days": 2}, + conditional={"Subject": 308} + ) + + with pytest.raises( + ValueError, match="There are new groups for the factors \('Subject',\) and 'sample_new_groups' is False." + ): + # default: sample_new_groups=False + plot_slopes(model, idata, "Days", "Subject") \ No newline at end of file