From e2805074fffe5ecd25806304eaadb2707fd61465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 21 Sep 2023 13:51:06 +0200 Subject: [PATCH 01/13] Added missing doc string part for center_predictors Perhaps the rationale behind centering could be added to help user decide whether to turn it on or not. But I'm not entirely sure about this. --- bambi/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bambi/models.py b/bambi/models.py index a1bf37c8b..8a9a85c64 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -86,6 +86,10 @@ class Model: noncentered : bool If ``True`` (default), uses a non-centered parameterization for normal hyperpriors on grouped parameters. If ``False``, naive (centered) parameterization is used. + center_predictors : bool + If ``True`` (default), and if there is an intercept in the common terms, the data is + centered by subtracting the mean. The centering is undone after sampling to provide + the actual intercept in all distributional components that have an intercept. extra_namespace : dict, optional Additional user supplied variables with transformations or data to include in the environment where the formula is evaluated. Defaults to `None`. From d40a1d84094f2abba264d713564ea9cd51ea9824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Fri, 22 Sep 2023 13:05:44 +0200 Subject: [PATCH 02/13] Clarified that centering affects prior interpretation in models.py --- bambi/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bambi/models.py b/bambi/models.py index 8a9a85c64..b2afb687e 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -89,7 +89,9 @@ class Model: center_predictors : bool If ``True`` (default), and if there is an intercept in the common terms, the data is centered by subtracting the mean. The centering is undone after sampling to provide - the actual intercept in all distributional components that have an intercept. 
+ the actual intercept in all distributional components that have an intercept. Note + that this changes the interpretation of the prior on the intercept because it refers + to the intercept of the centered data. extra_namespace : dict, optional Additional user supplied variables with transformations or data to include in the environment where the formula is evaluated. Defaults to `None`. From d6a6b7c9dbb4147ea64d5ee26ed1d3c7f2b95029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Tue, 26 Sep 2023 12:03:43 +0200 Subject: [PATCH 03/13] Applied black to make CI happy. Problem must have been some invisible character, or so. Also deleted superfluous space unrelated to black's complaint --- bambi/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/models.py b/bambi/models.py index 93a5aaff5..860ff77dc 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -89,7 +89,7 @@ class Model: center_predictors : bool If ``True`` (default), and if there is an intercept in the common terms, the data is centered by subtracting the mean. The centering is undone after sampling to provide - the actual intercept in all distributional components that have an intercept. Note + the actual intercept in all distributional components that have an intercept. Note that this changes the interpretation of the prior on the intercept because it refers to the intercept of the centered data. extra_namespace : dict, optional From a072c84809fe62e6efe1d7818f3d5d28d81e00ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Tue, 26 Sep 2023 12:19:51 +0200 Subject: [PATCH 04/13] Added --diff and --color to black in test.yml ... for more informative output. 
--- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b0479537..403d5be77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,7 +50,7 @@ jobs: shell: bash -l {0} run: | echo "Running black..." - black bambi --check + black bambi --check --diff --color echo "Checking code style with pylint..." pylint bambi From b79673a78673ea50ac8e624d9bba5edbd66b60c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 12:12:59 +0200 Subject: [PATCH 05/13] Fix for plotting order error --- bambi/interpret/plotting.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 13e97d3ea..ab0a19872 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -322,6 +322,10 @@ def plot_comparisons( "Must specify a covariate to 'average_by' when number of covariates" "passed to 'conditional' is greater than 3." ) + if isinstance(conditional, dict): + for k,v in conditional: + conditional[k] = sroted(v) + if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " From 314a4dc708e785bb4c43aae107cb18391aba5094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 12:43:43 +0200 Subject: [PATCH 06/13] Still fixing plot order --- bambi/interpret/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index ab0a19872..54e0ed836 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -323,7 +323,7 @@ def plot_comparisons( "passed to 'conditional' is greater than 3." 
) if isinstance(conditional, dict): - for k,v in conditional: + for k, v in conditional.items(): conditional[k] = sroted(v) if average_by is True: From 4217a49ecc951df3364db729c61de4852614e789 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 12:45:32 +0200 Subject: [PATCH 07/13] Another typo Cannot make changes locally at the moment. Hence so many commits ... --- bambi/interpret/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 54e0ed836..05530b26f 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -324,7 +324,7 @@ def plot_comparisons( ) if isinstance(conditional, dict): for k, v in conditional.items(): - conditional[k] = sroted(v) + conditional[k] = sorted(v) if average_by is True: raise ValueError( From 137464b8fc3a4dc5e0c58394401d20dfa25674b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 12:59:39 +0200 Subject: [PATCH 08/13] Black found some spaces ... 
--- bambi/interpret/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 05530b26f..637090f84 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -324,8 +324,8 @@ def plot_comparisons( ) if isinstance(conditional, dict): for k, v in conditional.items(): - conditional[k] = sorted(v) - + conditional[k] = sort + if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " From b62cc893653e18a0df7adb68e7ac4e41310a9393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 13:02:22 +0200 Subject: [PATCH 09/13] Arrg, more typos It's a bit annoying, I can only work from the web interface currently, but hopefully all typos are resolved now --- bambi/interpret/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 637090f84..0e505f2e2 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -324,7 +324,7 @@ def plot_comparisons( ) if isinstance(conditional, dict): for k, v in conditional.items(): - conditional[k] = sort + conditional[k] = sorted(v) if average_by is True: raise ValueError( From da5085ef93903f7231eb17bccb9493487d24295b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20T=C3=BCnnermann?= Date: Thu, 28 Sep 2023 13:09:07 +0200 Subject: [PATCH 10/13] Making pylint happy --- bambi/interpret/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 0e505f2e2..544de3caf 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -323,8 +323,8 @@ def plot_comparisons( "passed to 'conditional' is greater than 3." 
) if isinstance(conditional, dict): - for k, v in conditional.items(): - conditional[k] = sorted(v) + for key, value in conditional.items(): + conditional[key] = sorted(value) if average_by is True: raise ValueError( From 7bdd001128241c2781dfef70046a8dd431d37448 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 3 Oct 2023 06:26:39 +0200 Subject: [PATCH 11/13] add .drop_duplicates() before returning dataframe --- bambi/interpret/create_data.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bambi/interpret/create_data.py b/bambi/interpret/create_data.py index b532a3545..688874aea 100644 --- a/bambi/interpret/create_data.py +++ b/bambi/interpret/create_data.py @@ -70,6 +70,11 @@ def _grid_level( elif kind == "slopes": pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid, variable_info.name) + # After computing default values, fractional values may have been computed. + # Enforcing the dtype of "int" may create duplicate rows as it will round + # the fractional values. 
+ pairwise_grid = pairwise_grid.drop_duplicates() + return pairwise_grid From 7ca9762f784e183fd742cf4fdcd2d33687cbf2b1 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 3 Oct 2023 06:28:04 +0200 Subject: [PATCH 12/13] use np.unique instead of get_unique_levels func (and remove) --- bambi/interpret/plot_types.py | 16 ++++++++-------- bambi/interpret/utils.py | 13 +------------ 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/bambi/interpret/plot_types.py b/bambi/interpret/plot_types.py index 43c273414..276676d55 100644 --- a/bambi/interpret/plot_types.py +++ b/bambi/interpret/plot_types.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from bambi.interpret.utils import Covariates, get_unique_levels, get_group_offset, identity +from bambi.interpret.utils import Covariates, get_group_offset, identity def plot_numeric( @@ -49,7 +49,7 @@ def plot_numeric( ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4) elif "group" in covariates and not "panel" in covariates: ax = axes[0] - colors = get_unique_levels(plot_data[color]) + colors = np.unique(plot_data[color]) for i, clr in enumerate(colors): idx = (plot_data[color] == clr).to_numpy() values_main = transform_main(plot_data.loc[idx, main]) @@ -62,7 +62,7 @@ def plot_numeric( color=f"C{i}", ) elif not "group" in covariates and "panel" in covariates: - panels = get_unique_levels(plot_data[panel]) + panels = np.unique(plot_data[panel]) for ax, pnl in zip(axes.ravel(), panels): idx = (plot_data[panel] == pnl).to_numpy() values_main = transform_main(plot_data.loc[idx, main]) @@ -70,8 +70,8 @@ def plot_numeric( ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.4) ax.set(title=f"{panel} = {pnl}") elif "group" in covariates and "panel" in covariates: - colors = get_unique_levels(plot_data[color]) - panels = get_unique_levels(plot_data[panel]) + colors = np.unique(plot_data[color]) + panels = np.unique(plot_data[panel]) if color == panel: for i, 
(ax, pnl) in enumerate(zip(axes.ravel(), panels)): idx = (plot_data[panel] == pnl).to_numpy() @@ -138,20 +138,20 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool main, color, panel = covariates.main, covariates.group, covariates.panel covariates = {k: v for k, v in vars(covariates).items() if v is not None} - main_levels = get_unique_levels(plot_data[main]) + main_levels = np.unique(plot_data[main]) main_levels_n = len(main_levels) idxs_main = np.arange(main_levels_n) y_hat_mean = plot_data["estimate"] y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values) if "group" in covariates: - colors = get_unique_levels(plot_data[color]) + colors = np.unique(plot_data[color]) colors_n = len(colors) offset_bounds = get_group_offset(colors_n) colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n) if "panel" in covariates: - panels = get_unique_levels(plot_data[panel]) + panels = np.unique(plot_data[panel]) if len(covariates) == 1: ax = axes[0] diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index 85ec9bce2..67fa8a204 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -124,7 +124,7 @@ def set_default_variable_values(self) -> np.ndarray: predictor_data, self.eps, dtype ) elif component.kind == "categoric": - values = get_unique_levels(predictor_data) + values = np.unique(predictor_data) return values @@ -377,17 +377,6 @@ def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray: raise ValueError("Group covariate must be numeric or categoric.") -def get_unique_levels(x: np.ndarray) -> np.ndarray: - """ - Get unique levels of a categoric variable. - """ - if hasattr(x, "dtype") and hasattr(x.dtype, "categories"): - levels = np.array((x.dtype.categories)) - else: - levels = np.unique(x) - return levels - - def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray: # Complementary log log function, scaled. 
# See following code to have an idea of how this function looks like From cdd40df85d4b897ce8bf8c76cbf761bdc0a3cac9 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 3 Oct 2023 06:28:54 +0200 Subject: [PATCH 13/13] sort values if conditional dict in plot_comparisons and plot_slopes --- bambi/interpret/plotting.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 544de3caf..3dfaaa572 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -323,8 +323,7 @@ def plot_comparisons( "passed to 'conditional' is greater than 3." ) if isinstance(conditional, dict): - for key, value in conditional.items(): - conditional[key] = sorted(value) + conditional = {key: sorted(value) for key, value in conditional.items()} if average_by is True: raise ValueError( @@ -473,6 +472,7 @@ def plot_slopes( if conditional is None and average_by is None: raise ValueError("Must specify at least one of 'conditional' or 'average_by'.") + if conditional is not None: if not isinstance(conditional, str): if len(conditional) > 3 and average_by is None: @@ -480,6 +480,8 @@ def plot_slopes( "Must specify a covariate to 'average_by' when number of covariates" "passed to 'conditional' is greater than 3." ) + if isinstance(conditional, dict): + conditional = {key: sorted(value) for key, value in conditional.items()} if average_by is True: raise ValueError(