Added the ability to use safe priors for hierarchical models #331

Merged
34 commits merged on Dec 16, 2023
Changes from 8 commits
Commits (34)
130ecc6
modified hssm.py to add special case for HDDM
digicosmos86 Nov 28, 2023
1401861
plugged in a few functions from bambi to handle default priors
digicosmos86 Nov 28, 2023
06795f3
added logic to modify the class to add default priors
digicosmos86 Nov 28, 2023
0dc986a
moved merge_dict to param.py to avoid circular import
digicosmos86 Nov 28, 2023
08a92e1
ensures that tests pass
digicosmos86 Nov 28, 2023
2a4587c
update software versions
digicosmos86 Nov 28, 2023
ce3e067
adjust how bounds are passed
digicosmos86 Nov 30, 2023
c3eeb14
add ruff ignore item, place warnings
digicosmos86 Nov 30, 2023
76a8411
update ci workflow
digicosmos86 Dec 5, 2023
2725e05
exclude ddm_sdv and ddm_full
digicosmos86 Dec 5, 2023
5136e04
fixed minor bugs in param.py
digicosmos86 Dec 5, 2023
24ab25e
use deepcopy to avoid errors
digicosmos86 Dec 5, 2023
c8d1981
added tests for safe prior strategy
digicosmos86 Dec 5, 2023
68cacd5
suppress jax warning
digicosmos86 Dec 5, 2023
01ba4dc
specify float type for each test file
digicosmos86 Dec 5, 2023
bcda159
update hssm version
digicosmos86 Dec 5, 2023
2567e32
Updated default parameter specifications
digicosmos86 Dec 7, 2023
84f1d58
suppress some warnings
digicosmos86 Dec 7, 2023
7a50254
bump ssm-simulators version
digicosmos86 Dec 7, 2023
6c935d0
update ssm-simulators
digicosmos86 Dec 7, 2023
d9b2827
update ssm-simulators
digicosmos86 Dec 7, 2023
f457631
fix a test
digicosmos86 Dec 7, 2023
8d006a0
Merge branch 'safe-prior-strategy' of https://github.com/lnccbrown/HS…
digicosmos86 Dec 7, 2023
509c0e2
set default init to
digicosmos86 Dec 7, 2023
2cb9608
bump ssm-simulators
digicosmos86 Dec 12, 2023
e7da626
Merge branch 'safe-prior-strategy' into update-documentation-020
digicosmos86 Dec 13, 2023
e0bac72
added string representation for generalized logit
digicosmos86 Dec 13, 2023
67575a0
fixed a bug where link_settings does not work in hssm
digicosmos86 Dec 13, 2023
6ae87ad
added documentation for GPU support
digicosmos86 Dec 13, 2023
820ff00
fix bugs in param.py
digicosmos86 Dec 13, 2023
e605503
added documentation for hierachical modeling
digicosmos86 Dec 15, 2023
0d4c507
added changelog
digicosmos86 Dec 15, 2023
5061f44
Merge branch 'main' into safe-prior-strategy
digicosmos86 Dec 15, 2023
32b6af1
changed version to 0.2.0b1
digicosmos86 Dec 15, 2023
15 changes: 5 additions & 10 deletions pyproject.toml
@@ -30,7 +30,7 @@ bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
pytensor = "<=2.17.3"
pytensor = "<2.17.4"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
@@ -69,7 +69,7 @@ profile = "black"

[tool.ruff]
line-length = 88
target-version = "py39"
target-version = "py310"
unfixable = ["E711"]

select = [
@@ -132,6 +132,8 @@ ignore = [
"B020",
# Function definition does not bind loop variable
"B023",
# `zip()` without an explicit `strict=`
"B905",
# Functions defined inside a loop must not use variables redefined in the loop
# "B301", # not yet implemented
# Too many arguments to function call
@@ -166,14 +168,7 @@ ignore = [
"TID252",
]

exclude = [
".github",
"docs",
"notebook",
"tests",
"src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c",
"src/hssm/likelihoods/hddm_wfpt/wfpt.cpp",
]
exclude = [".github", "docs", "notebook", "tests"]

[tool.ruff.pydocstyle]
convention = "numpy"
34 changes: 31 additions & 3 deletions src/hssm/hssm.py
@@ -24,6 +24,7 @@
import seaborn as sns
import xarray as xr
from bambi.model_components import DistributionalComponent
from bambi.transformations import transformations_namespace
Collaborator:
What is that one for, actually?

Collaborator (Author):
I think there is a default namespace that has to be included. It can be found in the Bambi source code here:

https://github.com/bambinos/bambi/blob/312afa24b25385f5fee9e0331e88052598c39b59/bambi/models.py#L149-L155
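
To make the point above concrete, here is a minimal sketch of how the new extra_namespace argument is meant to be used; the toy data, the standardize helper, and the v formula are illustrative placeholders, not part of this PR. Whatever is passed in is merged on top of Bambi's built-in transformations_namespace and becomes callable when the model formula is evaluated.

```python
import numpy as np
import pandas as pd

import hssm  # assumes the hssm package built from this branch


def standardize(x):
    """Hypothetical user-supplied transformation to call inside a formula."""
    return (x - np.mean(x)) / np.std(x)


# Toy data with the columns HSSM expects (rt, response) plus a covariate.
rng = np.random.default_rng(0)
data = pd.DataFrame(
    {
        "rt": rng.uniform(0.3, 2.0, size=200),
        "response": rng.choice([-1, 1], size=200),
        "x": rng.normal(size=200),
        "participant_id": rng.integers(0, 10, size=200),
    }
)

model = hssm.HSSM(
    data=data,
    model="ddm",
    include=[{"name": "v", "formula": "v ~ standardize(x) + (1|participant_id)"}],
    # Merged into the formula evaluation environment on top of
    # bambi.transformations.transformations_namespace.
    extra_namespace={"standardize": standardize},
)
```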


from hssm.defaults import (
LoglikKind,
@@ -164,6 +165,9 @@ class HSSM:
recommended when you are using hierarchical models.
The default value is `None` when `hierarchical` is `False` and `"safe"` when
`hierarchical` is `True`.
extra_namespace : optional
Additional user supplied variables with transformations or data to include in
the environment where the formula is evaluated. Defaults to `None`.
**kwargs
Additional arguments passed to the `bmb.Model` object.

@@ -214,6 +218,7 @@ def __init__(
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
extra_namespace: dict[str, Any] | None = None,
**kwargs,
):
self.data = data
@@ -232,6 +237,11 @@
self.link_settings = link_settings
self.prior_settings = prior_settings

additional_namespace = transformations_namespace.copy()
if extra_namespace is not None:
additional_namespace.update(extra_namespace)
self.additional_namespace = additional_namespace

responses = self.data["response"].unique().astype(int)
self.n_responses = len(responses)
if self.n_responses == 2:
@@ -312,7 +322,12 @@ def __init__(
)

self.model = bmb.Model(
self.formula, data, family=self.family, priors=self.priors, **other_kwargs
self.formula,
data=data,
family=self.family,
priors=self.priors,
extra_namespace=extra_namespace,
**other_kwargs,
)

self._aliases = get_alias_dict(self.model, self._parent_param)
@@ -852,6 +867,8 @@ def _add_kwargs_and_p_outlier_to_include(
"""Process kwargs and p_outlier and add them to include."""
if include is None:
include = []
else:
include = include.copy()
params_in_include = [param["name"] for param in include]

# Process kwargs
@@ -913,7 +930,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
bounds = self.model_config.bounds.get(param_str)
param = Param(
param_str,
formula="1 + (1|participant_id)",
formula=f"{param_str} ~ 1 + (1|participant_id)",
bounds=bounds,
)
else:
@@ -956,15 +973,26 @@ def _find_parent(self) -> tuple[str, Param]:

def _override_defaults(self):
"""Override the default priors or links."""
is_ddm = (
self.model_name in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical"
) or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox")
for param in self.list_params:
param_obj = self.params[param]
if self.prior_settings == "safe":
param_obj.override_default_priors(self.data)
if is_ddm:
param_obj.override_default_priors_ddm(
self.data, self.additional_namespace
)
else:
param_obj.override_default_priors(
self.data, self.additional_namespace
)
elif self.link_settings == "log_logit":
param_obj.override_default_link()

def _process_all(self):
"""Process all params."""
assert self.list_params is not None
for param in self.list_params:
self.params[param].convert()

160 changes: 147 additions & 13 deletions src/hssm/param.py
@@ -1,14 +1,16 @@
"""The Param utility class."""

import logging
from typing import Any, Union, cast
from copy import deepcopy
from typing import Any, Literal, Union, cast

import bambi as bmb
import numpy as np
import pandas as pd
from formulae import design_matrices

from .link import Link
from .prior import Prior
from .prior import Prior, get_default_prior, get_hddm_default_prior

# PEP604 union operator "|" not supported by pylint
# Fall back to old syntax
@@ -98,14 +100,7 @@ def override_default_link(self):

This is most likely because both default prior and default bounds are supplied.
"""
if self._is_converted:
raise ValueError(
(
"Cannot override the default link function for parameter %s."
+ " The object has already been processed."
)
% self.name,
)
self._ensure_not_converted(context="link")

if not self.is_regression or self._link_specified:
return # do nothing
@@ -136,8 +131,62 @@ def override_default_link(self):
upper,
)

def override_default_priors(self, data: pd.DataFrame):
"""Override the default priors.
def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
"""Override the default priors - the general case.

By supplying priors for all parameters in the regression, we can override the
defaults that Bambi uses.

Parameters
----------
data
The data used to fit the model.
eval_env
The environment used to evaluate the formula.
"""
self._ensure_not_converted(context="prior")

if not self.is_regression:
return

override_priors = {}
dm = self._get_design_matrices(data, eval_env)

has_common_intercept = False
for name, term in dm.common.terms.items():
if term.kind == "intercept":
has_common_intercept = True
override_priors[name] = get_default_prior(
"common_intercept", self.bounds
)
else:
override_priors[name] = get_default_prior("common", bounds=None)

for name, term in dm.group.terms.items():
if term.kind == "intercept":
if has_common_intercept:
override_priors[name] = get_default_prior("group_intercept", None)
else:
# treat the term as any other group-specific term
_logger.warning(
f"No common intercept. Bounds for parameter {self.name} is not"
+ " applied due to a current limitation of Bambi."
+ " This will change in the future."
)
override_priors[name] = get_default_prior(
"group_specific", bounds=None
)
else:
override_priors[name] = get_default_prior("group_specific", bounds=None)

if not self.prior:
self.prior = override_priors
else:
prior = cast(dict[str, ParamSpec], self.prior)
self.prior = merge_dicts(override_priors, prior)

def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, Any]):
"""Override the default priors - the ddm case.

By supplying priors for all parameters in the regression, we can override the
defaults that Bambi uses.
@@ -146,8 +195,82 @@ def override_default_priors(self, data: pd.DataFrame):
----------
data
The data used to fit the model.
eval_env
The environment used to evaluate the formula.
"""
self._ensure_not_converted(context="prior")
assert self.name is not None

if not self.is_regression:
return

override_priors = {}
dm = self._get_design_matrices(data, eval_env)

has_common_intercept = False
for name, term in dm.common.terms.items():
if term.kind == "intercept":
has_common_intercept = True
override_priors[name] = get_hddm_default_prior(
"common_intercept", self.name, self.bounds
)
else:
override_priors[name] = get_hddm_default_prior(
"common", self.name, bounds=None
)

for name, term in dm.group.terms.items():
if term.kind == "intercept":
if has_common_intercept:
override_priors[name] = get_hddm_default_prior(
"group_intercept", self.name, bounds=None
)
else:
# treat the term as any other group-specific term
_logger.warning(
f"No common intercept. Bounds for parameter {self.name} is not"
+ " applied due to a current limitation of Bambi."
+ " This will change in the future."
)
override_priors[name] = get_hddm_default_prior(
"group_intercept", self.name, bounds=None
)
else:
override_priors[name] = get_hddm_default_prior(
"group_specific", self.name, bounds=None
)

if not self.prior:
self.prior = override_priors
else:
prior = cast(dict[str, ParamSpec], self.prior)
self.prior = merge_dicts(override_priors, prior)

def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]):
"""Get the design matrices for the regression.

Parameters
----------
data
A pandas DataFrame
eval_env
The evaluation environment
"""
return # Will implement in the next PR
formula = cast(str, self.formula)
rhs = formula.split("~")[1]
formula = "rt ~ " + rhs
dm = design_matrices(formula, data=data, extra_namespace=extra_namespace)

return dm

def _ensure_not_converted(self, context=Literal["link", "prior"]):
"""Ensure that the object has not been converted."""
if self._is_converted:
context = "link function" if context == "link" else "priors"
raise ValueError(
f"Cannot override the default {context} for parameter {self.name}."
+ " The object has already been processed."
)

def set_parent(self):
"""Set the Param as parent."""
@@ -531,3 +654,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior:
return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0)
else:
return bmb.Prior(name="Uniform", lower=lower, upper=upper)


def merge_dicts(dict1: dict, dict2: dict) -> dict:
Collaborator:
Strictly speaking this is not a merge, right? It seems that if a key is in merged but the values are not dicts, you override the value from dict1 with whatever you find in dict2. So maybe something like override_dict?

Collaborator (Author):
This actually raises a question: should the override be recursive? For example, if 1|participant_id has a user-defined prior with only sigma having a defined hyperprior, should we also add our default mu? It seems that if we don't, a Bambi or PyMC default will be applied.

Collaborator:
I would say no, because even then we would not cover the case in full generality. The logic would only apply if the user-defined distribution for 1|participant_id matches the one for which we have hyperpriors available. If users define their own prior, they should define it all the way down, I would say.

"""Recursively merge two dictionaries."""
merged = deepcopy(dict1)
for key, value in dict2.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = merge_dicts(merged[key], value)
else:
merged[key] = value
return merged
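
To make the recursive behavior discussed in the thread above concrete, here is a small illustration; the prior dictionaries are made up, and it assumes merge_dicts is importable from hssm.param as defined in this diff. Nested dictionaries are merged level by level, and for non-dict values the entry from the second dictionary wins.

```python
from hssm.param import merge_dicts  # assumes the module path from this diff

defaults = {
    "Intercept": {"name": "Normal", "mu": 0.0, "sigma": 2.0},
    "1|participant_id": {
        "name": "Normal",
        "mu": 0.0,
        "sigma": {"name": "HalfNormal", "sigma": 1.0},
    },
}

# The user only overrides the hyperprior's sigma for the group-specific intercept.
user = {"1|participant_id": {"sigma": {"sigma": 0.5}}}

merged = merge_dicts(defaults, user)
assert merged["1|participant_id"]["sigma"] == {"name": "HalfNormal", "sigma": 0.5}
assert merged["1|participant_id"]["mu"] == 0.0       # kept from the defaults
assert merged["Intercept"] == defaults["Intercept"]  # untouched
```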