-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added the ability to use safe priors for hierarchical models #331
Changes from 8 commits
130ecc6
1401861
06795f3
0dc986a
08a92e1
2a4587c
ce3e067
c3eeb14
76a8411
2725e05
5136e04
24ab25e
c8d1981
68cacd5
01ba4dc
bcda159
2567e32
84f1d58
7a50254
6c935d0
d9b2827
f457631
8d006a0
509c0e2
2cb9608
e7da626
e0bac72
67575a0
6ae87ad
820ff00
e605503
0d4c507
5061f44
32b6af1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
"""The Param utility class.""" | ||
|
||
import logging | ||
from typing import Any, Union, cast | ||
from copy import deepcopy | ||
from typing import Any, Literal, Union, cast | ||
|
||
import bambi as bmb | ||
import numpy as np | ||
import pandas as pd | ||
from formulae import design_matrices | ||
|
||
from .link import Link | ||
from .prior import Prior | ||
from .prior import Prior, get_default_prior, get_hddm_default_prior | ||
|
||
# PEP604 union operator "|" not supported by pylint | ||
# Fall back to old syntax | ||
|
@@ -98,14 +100,7 @@ def override_default_link(self): | |
|
||
This is most likely because both default prior and default bounds are supplied. | ||
""" | ||
if self._is_converted: | ||
raise ValueError( | ||
( | ||
"Cannot override the default link function for parameter %s." | ||
+ " The object has already been processed." | ||
) | ||
% self.name, | ||
) | ||
self._ensure_not_converted(context="link") | ||
|
||
if not self.is_regression or self._link_specified: | ||
return # do nothing | ||
|
@@ -136,8 +131,62 @@ def override_default_link(self): | |
upper, | ||
) | ||
|
||
def override_default_priors(self, data: pd.DataFrame): | ||
"""Override the default priors. | ||
def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
    """Override the default priors - the general case.

    Builds one prior per term of the regression so that the defaults Bambi
    would otherwise pick are replaced. Priors already present in
    ``self.prior`` are merged on top of the generated ones, so user-supplied
    entries win.

    Parameters
    ----------
    data
        The data used to fit the model.
    eval_env
        The environment used to evaluate the formula.
    """
    self._ensure_not_converted(context="prior")

    # Only regression parameters have per-term priors to override.
    if not self.is_regression:
        return

    design = self._get_design_matrices(data, eval_env)
    generated = {}

    # Common terms: only the intercept receives the parameter's bounds.
    common_intercept_found = False
    for term_name, term in design.common.terms.items():
        if term.kind != "intercept":
            generated[term_name] = get_default_prior("common", bounds=None)
            continue
        common_intercept_found = True
        generated[term_name] = get_default_prior("common_intercept", self.bounds)

    # Group-specific terms: a group intercept gets its dedicated prior only
    # when a common intercept exists; otherwise it is handled like any other
    # group-specific term (after warning that bounds are not applied).
    for term_name, term in design.group.terms.items():
        is_intercept = term.kind == "intercept"
        if is_intercept and common_intercept_found:
            generated[term_name] = get_default_prior("group_intercept", None)
            continue
        if is_intercept:
            _logger.warning(
                f"No common intercept. Bounds for parameter {self.name} is not"
                + " applied due to a current limitation of Bambi."
                + " This will change in the future."
            )
        generated[term_name] = get_default_prior("group_specific", bounds=None)

    # Existing (user-specified) priors take precedence over generated ones.
    if self.prior:
        user_priors = cast(dict[str, ParamSpec], self.prior)
        generated = merge_dicts(generated, user_priors)
    self.prior = generated
|
||
def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, Any]): | ||
"""Override the default priors - the ddm case. | ||
|
||
By supplying priors for all parameters in the regression, we can override the | ||
defaults that Bambi uses. | ||
|
@@ -146,8 +195,82 @@ def override_default_priors(self, data: pd.DataFrame): | |
---------- | ||
data | ||
The data used to fit the model. | ||
eval_env | ||
The environment used to evaluate the formula. | ||
""" | ||
self._ensure_not_converted(context="prior") | ||
assert self.name is not None | ||
|
||
if not self.is_regression: | ||
return | ||
|
||
override_priors = {} | ||
dm = self._get_design_matrices(data, eval_env) | ||
|
||
has_common_intercept = False | ||
for name, term in dm.common.terms.items(): | ||
if term.kind == "intercept": | ||
has_common_intercept = True | ||
override_priors[name] = get_hddm_default_prior( | ||
"common_intercept", self.name, self.bounds | ||
) | ||
else: | ||
override_priors[name] = get_hddm_default_prior( | ||
"common", self.name, bounds=None | ||
) | ||
|
||
for name, term in dm.group.terms.items(): | ||
if term.kind == "intercept": | ||
if has_common_intercept: | ||
override_priors[name] = get_hddm_default_prior( | ||
"group_intercept", self.name, bounds=None | ||
) | ||
else: | ||
# treat the term as any other group-specific term | ||
_logger.warning( | ||
f"No common intercept. Bounds for parameter {self.name} is not" | ||
+ " applied due to a current limitation of Bambi." | ||
+ " This will change in the future." | ||
) | ||
override_priors[name] = get_hddm_default_prior( | ||
"group_intercept", self.name, bounds=None | ||
) | ||
else: | ||
override_priors[name] = get_hddm_default_prior( | ||
"group_specific", self.name, bounds=None | ||
) | ||
|
||
if not self.prior: | ||
self.prior = override_priors | ||
else: | ||
prior = cast(dict[str, ParamSpec], self.prior) | ||
self.prior = merge_dicts(override_priors, prior) | ||
|
||
def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]):
    """Get the design matrices for the regression.

    Parameters
    ----------
    data
        A pandas DataFrame
    extra_namespace
        The evaluation environment; forwarded to ``design_matrices`` as its
        ``extra_namespace`` argument.

    Returns
    -------
    The object returned by ``formulae.design_matrices`` for the rewritten
    formula.
    """
    formula = cast(str, self.formula)
    # Keep only the right-hand side and use "rt" as the response term.
    # NOTE(review): presumably ``data`` always has an "rt" column - confirm.
    rhs = formula.split("~")[1]
    formula = "rt ~ " + rhs
    return design_matrices(formula, data=data, extra_namespace=extra_namespace)
|
||
def _ensure_not_converted(self, context=Literal["link", "prior"]): | ||
"""Ensure that the object has not been converted.""" | ||
if self._is_converted: | ||
context = "link function" if context == "link" else "priors" | ||
raise ValueError( | ||
f"Cannot override the default {context} for parameter {self.name}." | ||
+ " The object has already been processed." | ||
) | ||
|
||
def set_parent(self): | ||
"""Set the Param as parent.""" | ||
|
@@ -531,3 +654,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior: | |
return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0) | ||
else: | ||
return bmb.Prior(name="Uniform", lower=lower, upper=upper) | ||
|
||
|
||
def merge_dicts(dict1: dict, dict2: dict) -> dict: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Strictly speaking, this is not a merge, right? Seems like if […] So maybe something like […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This actually raises a question: should the override be recursive? For example, if […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would say no, because even then we would not cover the case in full generality. If the user defines their own prior, they should define it all the way down. |
||
"""Recursively merge two dictionaries.""" | ||
merged = deepcopy(dict1) | ||
for key, value in dict2.items(): | ||
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): | ||
merged[key] = merge_dicts(merged[key], value) | ||
else: | ||
merged[key] = value | ||
return merged |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is that one for actually?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is a default namespace that has to be included. This can be found in the Bambi source code here:
https://github.com/bambinos/bambi/blob/312afa24b25385f5fee9e0331e88052598c39b59/bambi/models.py#L149-L155