From 130ecc6ac9809cbe92c14dd4b677fd9b13d50d67 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 10:34:33 -0500 Subject: [PATCH 01/31] modified hssm.py to add special case for HDDM --- src/hssm/hssm.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index d17b2782..34eef856 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -24,6 +24,7 @@ import seaborn as sns import xarray as xr from bambi.model_components import DistributionalComponent +from bambi.transformations import transformations_namespace from hssm.defaults import ( LoglikKind, @@ -164,6 +165,15 @@ class HSSM: recommended when you are using hierarchical models. The default value is `None` when `hierarchical` is `False` and `"safe"` when `hierarchical` is `True`. + center_predictors : optional + If `True`, and if there is an intercept in the common terms, the + data is centered by subtracting the mean. The centering is undone after sampling + to provide the actual intercept in all distributional components that have an + intercept. Note that this changes the interpretation of the prior on the + intercept because it refers to the intercept of the centered data. + extra_namespace : optional + Additional user supplied variables with transformations or data to include in + the environment where the formula is evaluated. Defaults to `None`. **kwargs Additional arguments passed to the `bmb.Model` object. 
@@ -214,6 +224,8 @@ def __init__( hierarchical: bool = False, link_settings: Literal["log_logit"] | None = None, prior_settings: Literal["safe"] | None = None, + center_predictors: bool = False, + extra_namespace: dict[str, Any] | None = None, **kwargs, ): self.data = data @@ -232,6 +244,11 @@ def __init__( self.link_settings = link_settings self.prior_settings = prior_settings + additional_namespace = transformations_namespace.copy() + if extra_namespace is not None: + additional_namespace.update(extra_namespace) + self.additional_namespace = additional_namespace + responses = self.data["response"].unique().astype(int) self.n_responses = len(responses) if self.n_responses == 2: @@ -312,7 +329,13 @@ def __init__( ) self.model = bmb.Model( - self.formula, data, family=self.family, priors=self.priors, **other_kwargs + self.formula, + data=data, + family=self.family, + priors=self.priors, + center_predictors=center_predictors, + extra_namespace=extra_namespace, + **other_kwargs, ) self._aliases = get_alias_dict(self.model, self._parent_param) @@ -852,6 +875,8 @@ def _add_kwargs_and_p_outlier_to_include( """Process kwargs and p_outlier and add them to include.""" if include is None: include = [] + else: + include = include.copy() params_in_include = [param["name"] for param in include] # Process kwargs @@ -956,15 +981,26 @@ def _find_parent(self) -> tuple[str, Param]: def _override_defaults(self): """Override the default priors or links.""" + is_ddm = ( + self.model in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical" + ) or (self.model == "ddm_full" and self.loglik_kind == "blackbox") for param in self.list_params: param_obj = self.params[param] if self.prior_settings == "safe": - param_obj.override_default_priors(self.data) + if is_ddm: + param_obj.override_default_priors_ddm( + self.data, self.additional_namespace + ) + else: + param_obj.override_default_priors( + self.data, self.additional_namespace + ) elif self.link_settings == "log_logit": 
param_obj.override_default_link() def _process_all(self): """Process all params.""" + assert self.list_params is not None for param in self.list_params: self.params[param].convert() From 1401861983dc00882e66cd16cd202d0287997ff2 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 10:35:48 -0500 Subject: [PATCH 02/31] plugged in a few functions from bambi to handle default priors --- src/hssm/prior.py | 159 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 1 deletion(-) diff --git a/src/hssm/prior.py b/src/hssm/prior.py index 3cc21244..2edb9da9 100644 --- a/src/hssm/prior.py +++ b/src/hssm/prior.py @@ -9,7 +9,8 @@ 3. The ability to shorten the output of bmb.Prior. """ -from typing import Callable +from statistics import mean +from typing import Any, Callable import bambi as bmb import numpy as np @@ -142,3 +143,159 @@ def TruncatedDist(name): ) return TruncatedDist + + +def generate_prior( + dist: str | dict | int | float, bounds: tuple[float, float] | None = None, **kwargs +): + """Generate a Prior distribution. + + The parameter ``kwargs`` is used to pass hyperpriors that are assigned to the + parameters of the prior to be built. + + This function is taken from bambi.priors.prior.py and modified to handle bounds. + + Parameters + ---------- + dist: + If a string, it is the name of the prior distribution with default values taken + from ``SETTINGS_DISTRIBUTIONS``. If a number, it is a factor used to scale the + standard deviation of the priors generated automatically by Bambi. If a `dict`, + it must contain a ``"dist"`` key with the name of the distribution and other + keys. + bounds: optional + A tuple of two floats indicating the lower and upper bounds of the prior. + + Raises + ------ + ValueError + If ``dist`` is not a string, number, or dict. + + Returns + ------- + Prior + The Prior instance. 
+ """ + if isinstance(dist, str): + default_settings = HSSM_SETTINGS_DISTRIBUTIONS[dist] + if kwargs: + for k, v in kwargs.items(): + default_settings[k] = generate_prior(v) + prior: Prior | int | float = Prior(dist, bounds=bounds, **default_settings) + elif isinstance(dist, dict): + prior_settings = dist.copy() + dist = prior_settings.pop("dist") + for k, v in prior_settings.items(): + prior_settings[k] = generate_prior(v) + prior = generate_prior(dist, bounds=bounds, **prior_settings) + elif isinstance(dist, (int, float)): + if bounds is not None: + lower, upper = bounds + if dist < lower or dist > upper: + raise ValueError( + f"The prior value {dist} is outside the bounds {bounds}." + ) + prior = dist + else: + raise ValueError( + "'dist' must be the name of a distribution or a numeric value." + ) + return prior + + +def get_default_prior(term_type: str, bounds: tuple[float, float] | None): + """Generate a Prior based on the default settings. + + The following summarizes default priors for each type of term: + + * common_intercept: Bounded Normal prior (N(mean(bounds), 0.25)). + * common: Normal prior (N(0, 0.25)). + * group_intercept: Normal prior where its sigma has a HalfFlat hyperprior. + * group_specific: Normal prior where its sigma has a HalfNormal hyperprior. + + This function is taken from bambi.priors.prior.py and modified to handle hssm- + specific situations. + + Parameters + ---------- + term_type : str + The type of the term for which the default prior is wanted. + bounds : tuple[float, float] | None + A tuple of two floats indicating the lower and upper bounds of the prior. + + Raises + ------ + ValueError + If ``term_type`` is not within the values listed above. + + Returns + ------- + prior: Prior + The instance of Prior according to the ``term_type``. 
+ """ + if term_type == "common": + prior = generate_prior("Normal", bounds=None) + elif term_type == "common_intercept": + if bounds is not None: + if any(np.isinf(b) for b in bounds): + prior = generate_prior("Normal", bounds=bounds) + else: + prior = generate_prior( + "Normal", mu=mean(bounds), sigma=0.25, bounds=bounds + ) + else: + prior = generate_prior("Normal") + elif term_type == "group_intercept": + prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=bounds) + elif term_type == "group_specific": + prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None) + else: + raise ValueError("Unrecognized term type.") + return prior + + +def get_hddm_default_prior( + term_type: str, param: str, bounds: tuple[float, float] | None +): + """Generate a Prior based on the default settings - the HDDM case.""" + if term_type == "common": + prior = generate_prior("Normal", bounds=None) + elif term_type == "common_intercept": + prior = generate_prior(HDDM_MU[param], bounds=bounds) + elif term_type == "group_intercept": + prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=bounds) + elif term_type == "group_specific": + prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None) + else: + raise ValueError("Unrecognized term type.") + return prior + + +HSSM_SETTINGS_DISTRIBUTIONS: dict[Any, Any] = { + "Normal": {"mu": 0.0, "sigma": 0.25}, + "Weibull": {"alpha": 1.5, "beta": 0.3}, +} + +HDDM_MU: dict[Any, Any] = { + "v": {"dist": "Normal", "mu": 2.0, "sigma": 3.0}, + "a": {"dist": "Gamma", "alpha": 1.5, "beta": 0.75}, + "z": {"dist": "Normal", "mu": 0.5, "sigma": 0.5}, + "t": {"dist": "Gamma", "alpha": 0.4, "beta": 0.2}, +} + +HDDM_SIGMA: dict[Any, Any] = { + "v": {"dist": "HalfNormal", "sigma": 2.0}, + "a": {"dist": "HalfNormal", "sigma": 0.1}, + "z": {"dist": "HalfNormal", "sigma": 0.05}, + "t": {"dist": "HalfNormal", "sigma": 1.0}, + "sv": {"dist": "HalfNormal", "sigma": 2.0}, + "sz": {"dist": "Beta", "alpha": 1.0, 
"beta": 3.0}, + "st": {"dist": "HalfNormal", "sigma": 0.3}, +} + +HDDM_SETTINGS_GROUP: dict[Any, Any] = { + "v": {"dist": "Normal", "mu": HDDM_MU["v"], "sigma": HDDM_SIGMA["v"]}, + "a": {"dist": "Gamma", "alpha": HDDM_MU["a"], "beta": HDDM_SIGMA["a"]}, + "z": {"dist": "Beta", "alpha": HDDM_MU["z"], "beta": HDDM_SIGMA["z"]}, + "t": {"dist": "Normal", "mu": HDDM_MU["t"], "sigma": HDDM_SIGMA["t"]}, +} From 06795f3b2cb79cad781136ab65e6cda11e1d04cd Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 10:36:10 -0500 Subject: [PATCH 03/31] added logic to modify the class to add default priors --- src/hssm/param.py | 141 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 129 insertions(+), 12 deletions(-) diff --git a/src/hssm/param.py b/src/hssm/param.py index 0dfde27c..606d7342 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -6,9 +6,11 @@ import bambi as bmb import numpy as np import pandas as pd +from formulae import design_matrices from .link import Link -from .prior import Prior +from .prior import Prior, get_default_prior, get_hddm_default_prior +from .utils import merge_dicts # PEP604 union operator "|" not supported by pylint # Fall back to old syntax @@ -98,14 +100,7 @@ def override_default_link(self): This is most likely because both default prior and default bounds are supplied. """ - if self._is_converted: - raise ValueError( - ( - "Cannot override the default link function for parameter %s." - + " The object has already been processed." - ) - % self.name, - ) + self._ensure_not_converted() if not self.is_regression or self._link_specified: return # do nothing @@ -136,8 +131,8 @@ def override_default_link(self): upper, ) - def override_default_priors(self, data: pd.DataFrame): - """Override the default priors. + def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): + """Override the default priors - the general case. 
By supplying priors for all parameters in the regression, we can override the defaults that Bambi uses. @@ -146,8 +141,130 @@ def override_default_priors(self, data: pd.DataFrame): ---------- data The data used to fit the model. + eval_env + The environment used to evaluate the formula. """ - return # Will implement in the next PR + self._ensure_not_converted() + + if not self.is_regression: + return + + override_priors = {} + dm = self._get_design_matrices(data, eval_env) + + has_common_intercept = False + for name, term in dm.common.terms.items(): + if term.kind == "intercept": + has_common_intercept = True + override_priors[name] = get_default_prior( + "common_intercept", self.bounds + ) + else: + override_priors[name] = get_default_prior("common", bounds=None) + + for name, term in dm.group.terms.items(): + if term.kind == "intercept": + if has_common_intercept: + override_priors[name] = get_default_prior( + "group_intercept", self.bounds + ) + else: + # treat the term as any other group-specific term + override_priors[name] = get_default_prior( + "group_specific", bounds=None + ) + else: + override_priors[name] = get_default_prior("group_specific", bounds=None) + + if not self.prior: + self.prior = override_priors + else: + prior = cast(dict[str, ParamSpec], self.prior) + self.prior = merge_dicts(override_priors, prior) + + def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, Any]): + """Override the default priors - the ddm case. + + By supplying priors for all parameters in the regression, we can override the + defaults that Bambi uses. + + Parameters + ---------- + data + The data used to fit the model. + eval_env + The environment used to evaluate the formula. 
+ """ + self._ensure_not_converted() + assert self.name is not None + + if not self.is_regression: + return + + override_priors = {} + dm = self._get_design_matrices(data, eval_env) + + has_common_intercept = False + for name, term in dm.common.terms.items(): + if term.kind == "intercept": + has_common_intercept = True + override_priors[name] = get_hddm_default_prior( + "common_intercept", self.name, self.bounds + ) + else: + override_priors[name] = get_hddm_default_prior( + "common", self.name, bounds=None + ) + + for name, term in dm.group.terms.items(): + if term.kind == "intercept": + if has_common_intercept: + override_priors[name] = get_hddm_default_prior( + "group_intercept", self.name, bounds=self.bounds + ) + else: + # treat the term as any other group-specific term + override_priors[name] = get_hddm_default_prior( + "group_intercept", self.name, bounds=None + ) + else: + override_priors[name] = get_hddm_default_prior( + "group_specific", self.name, bounds=None + ) + + if not self.prior: + self.prior = override_priors + else: + prior = cast(dict[str, ParamSpec], self.prior) + self.prior = merge_dicts(override_priors, prior) + + def _get_design_matrices(self, data: pd.DataFrame, eval_env: dict[str, Any]): + """Get the design matrices for the regression. + + Parameters + ---------- + data + A pandas DataFrame + eval_env + The evaluation environment + """ + formula = cast(str, self.formula) + rhs = formula.split("~")[1] + formula = "rt ~ " + rhs + dm = design_matrices(formula, data=data, eval_env=eval_env) + + return dm + + def _ensure_not_converted(self): + """Ensure that the object has not been converted.""" + if self._is_converted: + raise ValueError( + ( + "Cannot override the default priors for parameter %s." + + " The object has already been processed." 
+ ) + % self.name, + ) def set_parent(self): """Set the Param as parent.""" From 0dc986a17f24b873998122189642a2d790adb0f0 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 10:42:55 -0500 Subject: [PATCH 04/31] moved merge_dict to param.py to avoid circular import --- src/hssm/param.py | 13 ++++++++++++- src/hssm/utils.py | 12 ------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/hssm/param.py b/src/hssm/param.py index 606d7342..9a4b95d7 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -6,11 +6,11 @@ import bambi as bmb import numpy as np import pandas as pd +from deepcopy import deepcopy from formulae import design_matrices from .link import Link from .prior import Prior, get_default_prior, get_hddm_default_prior -from .utils import merge_dicts # PEP604 union operator "|" not supported by pylint # Fall back to old syntax @@ -648,3 +648,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior: return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0) else: return bmb.Prior(name="Uniform", lower=lower, upper=upper) + + +def merge_dicts(dict1: dict, dict2: dict) -> dict: + """Recursively merge two dictionaries.""" + merged = deepcopy(dict1) + for key, value in dict2.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = merge_dicts(merged[key], value) + else: + merged[key] = value + return merged diff --git a/src/hssm/utils.py b/src/hssm/utils.py index 2adcc1e6..5534aa28 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -10,7 +10,6 @@ """ import logging -from copy import deepcopy from typing import Any, Iterable, Literal, NewType import bambi as bmb @@ -54,17 +53,6 @@ def download_hf(path: str): return hf_hub_download(repo_id=REPO_ID, filename=path) -def merge_dicts(dict1: dict, dict2: dict) -> dict: - """Recursively merge two dictionaries.""" - merged = deepcopy(dict1) - for key, value in dict2.items(): - if key in merged and 
isinstance(merged[key], dict) and isinstance(value, dict): - merged[key] = merge_dicts(merged[key], value) - else: - merged[key] = value - return merged - - def make_alias_dict_from_parent(parent: Param) -> dict[str, str]: """Make aliases from the parent parameter. From 08a92e13813fad814f389fb81503c0e26de7cef3 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 12:00:52 -0500 Subject: [PATCH 05/31] ensures that tests pass --- src/hssm/hssm.py | 14 +++----------- src/hssm/param.py | 24 +++++++++++------------- src/hssm/prior.py | 19 +++++++++++++------ 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 34eef856..10d9eaaa 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -165,12 +165,6 @@ class HSSM: recommended when you are using hierarchical models. The default value is `None` when `hierarchical` is `False` and `"safe"` when `hierarchical` is `True`. - center_predictors : optional - If `True`, and if there is an intercept in the common terms, the - data is centered by subtracting the mean. The centering is undone after sampling - to provide the actual intercept in all distributional components that have an - intercept. Note that this changes the interpretation of the prior on the - intercept because it refers to the intercept of the centered data. extra_namespace : optional Additional user supplied variables with transformations or data to include in the environment where the formula is evaluated. Defaults to `None`. 
@@ -224,7 +218,6 @@ def __init__( hierarchical: bool = False, link_settings: Literal["log_logit"] | None = None, prior_settings: Literal["safe"] | None = None, - center_predictors: bool = False, extra_namespace: dict[str, Any] | None = None, **kwargs, ): @@ -333,7 +326,6 @@ def __init__( data=data, family=self.family, priors=self.priors, - center_predictors=center_predictors, extra_namespace=extra_namespace, **other_kwargs, ) @@ -938,7 +930,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]: bounds = self.model_config.bounds.get(param_str) param = Param( param_str, - formula="1 + (1|participant_id)", + formula=f"{param_str} ~ 1 + (1|participant_id)", bounds=bounds, ) else: @@ -982,8 +974,8 @@ def _find_parent(self) -> tuple[str, Param]: def _override_defaults(self): """Override the default priors or links.""" is_ddm = ( - self.model in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical" - ) or (self.model == "ddm_full" and self.loglik_kind == "blackbox") + self.model_name in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical" + ) or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox") for param in self.list_params: param_obj = self.params[param] if self.prior_settings == "safe": diff --git a/src/hssm/param.py b/src/hssm/param.py index 9a4b95d7..2f2309a9 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -1,12 +1,12 @@ """The Param utility class.""" import logging -from typing import Any, Union, cast +from copy import deepcopy +from typing import Any, Literal, Union, cast import bambi as bmb import numpy as np import pandas as pd -from deepcopy import deepcopy from formulae import design_matrices from .link import Link @@ -100,7 +100,7 @@ def override_default_link(self): This is most likely because both default prior and default bounds are supplied. 
""" - self._ensure_not_converted() + self._ensure_not_converted(context="link") if not self.is_regression or self._link_specified: return # do nothing @@ -144,7 +144,7 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): eval_env The environment used to evaluate the formula. """ - self._ensure_not_converted() + self._ensure_not_converted(context="prior") if not self.is_regression: return @@ -195,7 +195,7 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An eval_env The environment used to evaluate the formula. """ - self._ensure_not_converted() + self._ensure_not_converted(context="prior") assert self.name is not None if not self.is_regression: @@ -238,7 +238,7 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An prior = cast(dict[str, ParamSpec], self.prior) self.prior = merge_dicts(override_priors, prior) - def _get_design_matrices(self, data: pd.DataFrame, eval_env: dict[str, Any]): + def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]): """Get the design matrices for the regression. Parameters @@ -251,19 +251,17 @@ def _get_design_matrices(self, data: pd.DataFrame, eval_env: dict[str, Any]): formula = cast(str, self.formula) rhs = formula.split("~")[1] formula = "rt ~ " + rhs - dm = design_matrices(formula, data=data, eval_env=eval_env) + dm = design_matrices(formula, data=data, extra_namespace=extra_namespace) return dm - def _ensure_not_converted(self): + def _ensure_not_converted(self, context=Literal["link", "prior"]): """Ensure that the object has not been converted.""" if self._is_converted: + context = "link function" if context == "link" else "priors" raise ValueError( - ( - "Cannot override the default priors for parameter %s." - + " The object has already been processed." - ) - % self.name, + f"Cannot override the default {context} for parameter {self.name}." + + " The object has already been processed." 
) def set_parent(self): diff --git a/src/hssm/prior.py b/src/hssm/prior.py index 2edb9da9..0b9409b4 100644 --- a/src/hssm/prior.py +++ b/src/hssm/prior.py @@ -146,7 +146,9 @@ def TruncatedDist(name): def generate_prior( - dist: str | dict | int | float, bounds: tuple[float, float] | None = None, **kwargs + dist: str | dict | int | float | Prior, + bounds: tuple[float, float] | None = None, + **kwargs, ): """Generate a Prior distribution. @@ -184,10 +186,12 @@ def generate_prior( prior: Prior | int | float = Prior(dist, bounds=bounds, **default_settings) elif isinstance(dist, dict): prior_settings = dist.copy() - dist = prior_settings.pop("dist") + dist_name: str = prior_settings.pop("dist") for k, v in prior_settings.items(): prior_settings[k] = generate_prior(v) - prior = generate_prior(dist, bounds=bounds, **prior_settings) + prior = Prior(dist_name, bounds=bounds, **prior_settings) + elif isinstance(dist, Prior): + prior = dist elif isinstance(dist, (int, float)): if bounds is not None: lower, upper = bounds @@ -246,9 +250,9 @@ def get_default_prior(term_type: str, bounds: tuple[float, float] | None): else: prior = generate_prior("Normal") elif term_type == "group_intercept": - prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=bounds) + prior = generate_prior("Normal", mu="Normal", sigma="Weibull") elif term_type == "group_specific": - prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None) + prior = generate_prior("Normal", mu="Normal", sigma="Weibull") else: raise ValueError("Unrecognized term type.") return prior @@ -263,7 +267,7 @@ def get_hddm_default_prior( elif term_type == "common_intercept": prior = generate_prior(HDDM_MU[param], bounds=bounds) elif term_type == "group_intercept": - prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=bounds) + prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=None) elif term_type == "group_specific": prior = generate_prior("Normal", mu="Normal", sigma="Weibull", 
bounds=None) else: @@ -274,6 +278,9 @@ def get_hddm_default_prior( HSSM_SETTINGS_DISTRIBUTIONS: dict[Any, Any] = { "Normal": {"mu": 0.0, "sigma": 0.25}, "Weibull": {"alpha": 1.5, "beta": 0.3}, + "HalfNormal": {"sigma": 0.25}, + "Beta": {"alpha": 1.0, "beta": 1.0}, + "Gamma": {"k": 1.0, "theta": 1.0}, } HDDM_MU: dict[Any, Any] = { From 2a4587c4c72beafb4efea7639fb5a17bbe4c0f33 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 28 Nov 2023 16:45:34 -0500 Subject: [PATCH 06/31] update software versions --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 631d5826..0200e92f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ bambi = "^0.12.0" numpyro = "^0.12.1" hddm-wfpt = "^0.1.1" seaborn = "^0.13.0" -pytensor = "<=2.17.3" +pytensor = "<2.17.4" [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" From ce3e067f65281bff1b08283ae258787afffa3fe6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 30 Nov 2023 11:30:37 -0500 Subject: [PATCH 07/31] adjust how bounds are passed --- pyproject.toml | 11 ++--------- src/hssm/param.py | 10 ++++------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0200e92f..eccafae3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ profile = "black" [tool.ruff] line-length = 88 -target-version = "py39" +target-version = "py310" unfixable = ["E711"] select = [ @@ -166,14 +166,7 @@ ignore = [ "TID252", ] -exclude = [ - ".github", - "docs", - "notebook", - "tests", - "src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c", - "src/hssm/likelihoods/hddm_wfpt/wfpt.cpp", -] +exclude = [".github", "docs", "notebook", "tests"] [tool.ruff.pydocstyle] convention = "numpy" diff --git a/src/hssm/param.py b/src/hssm/param.py index 2f2309a9..e8e4195d 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -165,13 +165,11 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): for name, 
term in dm.group.terms.items(): if term.kind == "intercept": if has_common_intercept: - override_priors[name] = get_default_prior( - "group_intercept", self.bounds - ) + override_priors[name] = get_default_prior("group_intercept", None) else: # treat the term as any other group-specific term override_priors[name] = get_default_prior( - "group_specific", bounds=None + "group_specific", bounds=self.bounds ) else: override_priors[name] = get_default_prior("group_specific", bounds=None) @@ -220,12 +218,12 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An if term.kind == "intercept": if has_common_intercept: override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=self.bounds + "group_intercept", self.name, bounds=None ) else: # treat the term as any other group-specific term override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=None + "group_intercept", self.name, bounds=self.bounds ) else: override_priors[name] = get_hddm_default_prior( From c3eeb1405d52421233b7e3fb67eeb1d419e12c8d Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 30 Nov 2023 11:37:16 -0500 Subject: [PATCH 08/31] add ruff ignore item, place warnings --- pyproject.toml | 2 ++ src/hssm/param.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eccafae3..0031c008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,8 @@ ignore = [ "B020", # Function definition does not bind loop variable "B023", + # zip()` without an explicit `strict= + "B905", # Functions defined inside a loop must not use variables redefined in the loop # "B301", # not yet implemented # Too many arguments to function call diff --git a/src/hssm/param.py b/src/hssm/param.py index e8e4195d..5b78d2b4 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -168,8 +168,13 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): 
override_priors[name] = get_default_prior("group_intercept", None) else: # treat the term as any other group-specific term + _logger.warning( + f"No common intercept. Bounds for parameter {self.name} is not" + + " applied due to a current limitation of Bambi." + + " This will change in the future." + ) override_priors[name] = get_default_prior( - "group_specific", bounds=self.bounds + "group_specific", bounds=None ) else: override_priors[name] = get_default_prior("group_specific", bounds=None) @@ -222,8 +227,13 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An ) else: # treat the term as any other group-specific term + _logger.warning( + f"No common intercept. Bounds for parameter {self.name} is not" + + " applied due to a current limitation of Bambi." + + " This will change in the future." + ) override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=self.bounds + "group_intercept", self.name, bounds=None ) else: override_priors[name] = get_hddm_default_prior( From 76a8411cb10b708c05ee19e1fa3875696744e6fe Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:52:20 -0500 Subject: [PATCH 09/31] update ci workflow --- .github/workflows/run_tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index dc3e33ab..71b70fe5 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -2,7 +2,6 @@ name: Run tests on: pull_request: - push: jobs: run_tests: From 2725e056fbc774f03b4043f76bb16408c3472237 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:52:49 -0500 Subject: [PATCH 10/31] exclude ddm_sdv and ddm_full --- src/hssm/hssm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 10d9eaaa..80b7bc7d 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -974,8 +974,8 @@ def _find_parent(self) -> tuple[str, Param]: def 
_override_defaults(self): """Override the default priors or links.""" is_ddm = ( - self.model_name in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical" - ) or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox") + self.model_name == "ddm" and self.loglik_kind == "analytical" + ) # or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox") for param in self.list_params: param_obj = self.params[param] if self.prior_settings == "safe": From 5136e0493d2c35ae77fdbc1d0f3263d25020a47b Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:53:17 -0500 Subject: [PATCH 11/31] fixed minor bugs in param.py --- src/hssm/param.py | 104 +++++++++++++++++++++++++--------------------- 1 file changed, 56 insertions(+), 48 deletions(-) diff --git a/src/hssm/param.py b/src/hssm/param.py index 5b78d2b4..eac3ca99 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -153,37 +153,43 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): dm = self._get_design_matrices(data, eval_env) has_common_intercept = False - for name, term in dm.common.terms.items(): - if term.kind == "intercept": - has_common_intercept = True - override_priors[name] = get_default_prior( - "common_intercept", self.bounds - ) - else: - override_priors[name] = get_default_prior("common", bounds=None) - - for name, term in dm.group.terms.items(): - if term.kind == "intercept": - if has_common_intercept: - override_priors[name] = get_default_prior("group_intercept", None) - else: - # treat the term as any other group-specific term - _logger.warning( - f"No common intercept. Bounds for parameter {self.name} is not" - + " applied due to a current limitation of Bambi." - + " This will change in the future." 
+ if dm.common is not None: + for name, term in dm.common.terms.items(): + if term.kind == "intercept": + has_common_intercept = True + override_priors[name] = get_default_prior( + "common_intercept", self.bounds ) + else: + override_priors[name] = get_default_prior("common", bounds=None) + + if dm.group is not None: + for name, term in dm.group.terms.items(): + if term.kind == "intercept": + if has_common_intercept: + override_priors[name] = get_default_prior( + "group_intercept", None + ) + else: + # treat the term as any other group-specific term + _logger.warning( + f"No common intercept. Bounds for parameter {self.name} is" + + " not applied due to a current limitation of Bambi." + + " This will change in the future." + ) + override_priors[name] = get_default_prior( + "group_specific", bounds=None + ) + else: override_priors[name] = get_default_prior( "group_specific", bounds=None ) - else: - override_priors[name] = get_default_prior("group_specific", bounds=None) if not self.prior: self.prior = override_priors else: prior = cast(dict[str, ParamSpec], self.prior) - self.prior = merge_dicts(override_priors, prior) + self.prior = override_priors | prior def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, Any]): """Override the default priors - the ddm case. 
@@ -208,43 +214,45 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An dm = self._get_design_matrices(data, eval_env) has_common_intercept = False - for name, term in dm.common.terms.items(): - if term.kind == "intercept": - has_common_intercept = True - override_priors[name] = get_hddm_default_prior( - "common_intercept", self.name, self.bounds - ) - else: - override_priors[name] = get_hddm_default_prior( - "common", self.name, bounds=None - ) - - for name, term in dm.group.terms.items(): - if term.kind == "intercept": - if has_common_intercept: + if dm.common is not None: + for name, term in dm.common.terms.items(): + if term.kind == "intercept": + has_common_intercept = True override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=None + "common_intercept", self.name, self.bounds ) else: - # treat the term as any other group-specific term - _logger.warning( - f"No common intercept. Bounds for parameter {self.name} is not" - + " applied due to a current limitation of Bambi." - + " This will change in the future." + override_priors[name] = get_hddm_default_prior( + "common", self.name, bounds=None ) + + if dm.group is not None: + for name, term in dm.group.terms.items(): + if term.kind == "intercept": + if has_common_intercept: + override_priors[name] = get_hddm_default_prior( + "group_intercept", self.name, bounds=None + ) + else: + # treat the term as any other group-specific term + _logger.warning( + f"No common intercept. Bounds for parameter {self.name} is" + + " not applied due to a current limitation of Bambi." + + " This will change in the future." 
+ ) + override_priors[name] = get_hddm_default_prior( + "group_intercept", self.name, bounds=None + ) + else: override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=None + "group_specific", self.name, bounds=None ) - else: - override_priors[name] = get_hddm_default_prior( - "group_specific", self.name, bounds=None - ) if not self.prior: self.prior = override_priors else: prior = cast(dict[str, ParamSpec], self.prior) - self.prior = merge_dicts(override_priors, prior) + self.prior = override_priors | prior def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]): """Get the design matrices for the regression. From 24ab25e3508f1c77e3bc63d187fe6dd5370ecaab Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:53:34 -0500 Subject: [PATCH 12/31] use deepcopy to avoid errors --- src/hssm/prior.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/hssm/prior.py b/src/hssm/prior.py index 0b9409b4..eb9a028b 100644 --- a/src/hssm/prior.py +++ b/src/hssm/prior.py @@ -8,7 +8,7 @@ 2. The ability to still print out the prior before the truncation. 3. The ability to shorten the output of bmb.Prior. """ - +from copy import deepcopy from statistics import mean from typing import Any, Callable @@ -179,13 +179,13 @@ def generate_prior( The Prior instance. 
""" if isinstance(dist, str): - default_settings = HSSM_SETTINGS_DISTRIBUTIONS[dist] + default_settings = deepcopy(HSSM_SETTINGS_DISTRIBUTIONS[dist]) if kwargs: for k, v in kwargs.items(): default_settings[k] = generate_prior(v) prior: Prior | int | float = Prior(dist, bounds=bounds, **default_settings) elif isinstance(dist, dict): - prior_settings = dist.copy() + prior_settings = deepcopy(dist) dist_name: str = prior_settings.pop("dist") for k, v in prior_settings.items(): prior_settings[k] = generate_prior(v) @@ -214,8 +214,9 @@ def get_default_prior(term_type: str, bounds: tuple[float, float] | None): * common_intercept: Bounded Normal prior (N(mean(bounds), 0.25)). * common: Normal prior (N(0, 0.25)). - * group_intercept: Normal prior where its sigma has a HalfFlat hyperprior. - * group_specific: Normal prior where its sigma has a HalfNormal hyperprior. + * group_intercept: Normal prior N(N(0, 0.25), Weibull(1.5, 0.3). It's supposed to + be bounded but Bambi does not fully support it yet. + * group_specific: Normal prior N(N(0, 0.25), Weibull(1.5, 0.3). This function is taken from bambi.priors.prior.py and modified to handle hssm- specific situations. @@ -242,6 +243,7 @@ def get_default_prior(term_type: str, bounds: tuple[float, float] | None): elif term_type == "common_intercept": if bounds is not None: if any(np.isinf(b) for b in bounds): + # TODO: Make it more specific. 
prior = generate_prior("Normal", bounds=bounds) else: prior = generate_prior( @@ -280,7 +282,7 @@ def get_hddm_default_prior( "Weibull": {"alpha": 1.5, "beta": 0.3}, "HalfNormal": {"sigma": 0.25}, "Beta": {"alpha": 1.0, "beta": 1.0}, - "Gamma": {"k": 1.0, "theta": 1.0}, + "Gamma": {"alpha": 1.0, "beta": 1.0}, } HDDM_MU: dict[Any, Any] = { @@ -295,9 +297,9 @@ def get_hddm_default_prior( "a": {"dist": "HalfNormal", "sigma": 0.1}, "z": {"dist": "HalfNormal", "sigma": 0.05}, "t": {"dist": "HalfNormal", "sigma": 1.0}, - "sv": {"dist": "HalfNormal", "sigma": 2.0}, - "sz": {"dist": "Beta", "alpha": 1.0, "beta": 3.0}, - "st": {"dist": "HalfNormal", "sigma": 0.3}, + # "sv": {"dist": "HalfNormal", "sigma": 2.0}, + # "sz": {"dist": "Beta", "alpha": 1.0, "beta": 3.0}, + # "st": {"dist": "HalfNormal", "sigma": 0.3}, } HDDM_SETTINGS_GROUP: dict[Any, Any] = { From c8d1981581c675a5958b4f8c9d62d08cc947e379 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:53:54 -0500 Subject: [PATCH 13/31] added tests for safe prior strategy --- tests/test_param.py | 307 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) diff --git a/tests/test_param.py b/tests/test_param.py index 8b43588c..745ace5b 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -9,6 +9,9 @@ _make_priors_recursive, _make_bounded_prior, ) +from hssm.defaults import default_model_config + +hssm.set_floatX("float64") def test_param_creation_non_regression(): @@ -417,3 +420,307 @@ def test_param_override_default_link(caplog, formula, link, bounds, result): with pytest.raises(ValueError): param.override_default_link() + + +def _check_group_prior(group_prior): + assert isinstance(group_prior, bmb.Prior) + assert group_prior.dist is None + assert group_prior.name == "Normal" + + mu = group_prior.args["mu"] + sigma = group_prior.args["sigma"] + + assert isinstance(group_prior, bmb.Prior) + assert mu.name == "Normal" + assert mu.args["mu"] == 0.0 + assert mu.args["sigma"] == 0.25 + + 
assert isinstance(group_prior, bmb.Prior) + assert sigma.name == "Weibull" + assert sigma.args["alpha"] == 1.5 + assert sigma.args["beta"] == 0.3 + + +angle_params = default_model_config["angle"]["list_params"] +angle_bounds = default_model_config["angle"]["likelihoods"]["approx_differentiable"][ + "bounds" +].values() +param_and_bounds = zip(angle_params, angle_bounds) + + +@pytest.mark.parametrize( + ("param_name", "bounds"), + param_and_bounds, +) +def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds): + # Shouldn't do anything if the param is not a regression + param_non_reg = Param( + name=param_name, + prior={}, + ) + param_non_reg.override_default_priors(cavanagh_test, {}) + assert not param_non_reg.prior + + # The basic regression case, no group-specific terms + param = Param( + name=param_name, + formula=f"{param_name} ~ 1 + theta", + bounds=bounds, + ) + + param.override_default_priors(cavanagh_test, {}) + + assert param.prior is not None + + intercept_prior = param.prior["Intercept"] + slope_prior = param.prior["theta"] + + assert isinstance(intercept_prior, hssm.Prior) + assert intercept_prior.is_truncated + assert intercept_prior.bounds == bounds + assert intercept_prior.dist is not None + lower, upper = intercept_prior.bounds + _mu = intercept_prior._args["mu"] + if isinstance(_mu, np.ndarray): + assert _mu.item() == (lower + upper) / 2 + else: + assert _mu == (lower + upper) / 2 + assert intercept_prior._args["sigma"] == 0.25 + + assert isinstance(slope_prior, bmb.Prior) + assert slope_prior.dist is None + assert slope_prior.args["mu"] == 0.0 + assert slope_prior.args["sigma"] == 0.25 + + unif_prior = {"name": "Uniform", "lower": 0.0, "upper": 1.0} + set_prior = { + "Intercept": unif_prior, + "theta": unif_prior, + } + + param_with_prior = Param( + name=param_name, + formula=f"{param_name} ~ 1 + theta", + bounds=bounds, + prior=set_prior, + ) + + param_with_prior.override_default_priors(cavanagh_test, {}) + assert 
param_with_prior.prior == set_prior + + # The regression case, with group-specific terms + param_group = Param( + name=param_name, + formula=f"{param_name} ~ 1 + (1 + theta | participant_id)", + bounds=bounds, + ) + + param_group.override_default_priors(cavanagh_test, {}) + + assert all( + param in param_group.prior + for param in ["Intercept", "1|participant_id", "theta|participant_id"] + ) + + assert param_group.prior["Intercept"].is_truncated + + group_intercept_prior = param_group.prior["1|participant_id"] + group_slope_prior = param_group.prior["theta|participant_id"] + + _check_group_prior(group_intercept_prior) + _check_group_prior(group_slope_prior) + + param_no_common_intercept = Param( + name=param_name, + formula=f"{param_name} ~ 0 + (1 + theta | participant_id)", + bounds=bounds, + ) + + param_no_common_intercept.override_default_priors(cavanagh_test, {}) + assert "limitation" in caplog.records[0].msg + + assert "Intercept" not in param_no_common_intercept.prior + group_intercept_prior = param_group.prior["1|participant_id"] + group_slope_prior = param_group.prior["theta|participant_id"] + + _check_group_prior(group_intercept_prior) + _check_group_prior(group_slope_prior) + + +v_mu = {"name": "Normal", "mu": 2.0, "sigma": 3.0} +v_sigma = {"name": "HalfNormal", "sigma": 2.0} +v_prior = {"name": "Normal", "mu": v_mu, "sigma": v_sigma} + +a_mu = {"name": "Gamma", "alpha": 1.5, "beta": 0.75} +a_sigma = {"name": "HalfNormal", "sigma": 0.1} +a_prior = {"name": "Gamma", "alpha": a_mu, "beta": a_sigma} + +z_mu = {"name": "Normal", "mu": 0.5, "sigma": 0.5} +z_sigma = {"name": "HalfNormal", "sigma": 0.05} +z_prior = {"name": "Beta", "alpha": z_mu, "beta": z_sigma} + +t_mu = {"name": "Gamma", "alpha": 0.4, "beta": 0.2} +t_sigma = {"name": "HalfNormal", "sigma": 1} +t_prior = {"name": "Normal", "mu": t_mu, "sigma": t_sigma} + +sv = {"name": "HalfNormal", "sigma": 2.0} +st = {"name": "HalfNormal", "sigma": 0.3} +sz = {"name": "Beta", "alpha": 1.0, "beta": 1.0} + + 
+@pytest.mark.parametrize( + ("param_name", "mu", "prior"), + [ + ("v", v_mu, v_prior), + ("a", a_mu, a_prior), + ("z", z_mu, z_prior), + ("t", t_mu, t_prior), + ], +) +def test_param_override_default_priors_ddm( + cavanagh_test, caplog, param_name, mu, prior +): + # Shouldn't do anything if the param is not a regression + param_non_reg = Param( + name=param_name, + prior={}, + ) + param_non_reg.override_default_priors_ddm(cavanagh_test, {}) + assert not param_non_reg.prior + + bounds = (-10, 10) + + # The basic regression case, no group-specific terms + param = Param( + name=param_name, + formula=f"{param_name} ~ 1 + theta", + bounds=bounds, # invalid, just for testing + ) + + param.override_default_priors_ddm(cavanagh_test, {}) + + intercept_prior = param.prior["Intercept"] + slope_prior = param.prior["theta"] + + assert isinstance(intercept_prior, hssm.Prior) + assert intercept_prior.bounds == bounds + assert intercept_prior.dist is not None + mu1 = mu.copy() + assert intercept_prior.name == mu1.pop("name") + for key, val in mu1.items(): + val1 = intercept_prior._args[key] + if isinstance(val, np.ndarray): + val1 = val1.item() + assert val1 == val + + assert isinstance(slope_prior, bmb.Prior) + assert slope_prior.dist is None + assert slope_prior.args["mu"] == 0.0 + assert slope_prior.args["sigma"] == 0.25 + + # If prior is set, do not override + unif_prior = {"name": "Uniform", "lower": 0.0, "upper": 1.0} + set_prior = { + "Intercept": unif_prior, + "theta": unif_prior, + } + + param_with_prior = Param( + name=param_name, + formula=f"{param_name} ~ 1 + theta", + bounds=bounds, + prior=set_prior, + ) + + param_with_prior.override_default_priors_ddm(cavanagh_test, {}) + assert param_with_prior.prior == set_prior + + # The regression case, with group-specific terms + param_group = Param( + name=param_name, + formula=f"{param_name} ~ 1 + (1 + theta | participant_id)", + bounds=bounds, + ) + + param_group.override_default_priors_ddm(cavanagh_test, {}) + + assert 
all( + param in param_group.prior + for param in ["Intercept", "1|participant_id", "theta|participant_id"] + ) + + assert param_group.prior["Intercept"].is_truncated + + group_intercept_prior = param_group.prior["1|participant_id"] + group_slope_prior = param_group.prior["theta|participant_id"] + + def _check_group_prior_intercept_ddm(group_prior, prior): + assert isinstance(group_prior, bmb.Prior) + assert group_prior.dist is None + prior1 = prior.copy() + assert group_prior.name == prior1.pop("name") + for key, val in prior1.items(): + hyperprior = group_prior.args[key] + val1 = val.copy() + assert hyperprior.name == val1.pop("name") + for key2, val2 in val1.items(): + assert hyperprior.args[key2] == val2 + + _check_group_prior_intercept_ddm(group_intercept_prior, prior) + _check_group_prior(group_slope_prior) + + param_no_common_intercept = Param( + name=param_name, + formula=f"{param_name} ~ 0 + (1 + theta | participant_id)", + bounds=bounds, + ) + + param_no_common_intercept.override_default_priors_ddm(cavanagh_test, {}) + assert "limitation" in caplog.records[0].msg + + assert "Intercept" not in param_no_common_intercept.prior + group_intercept_prior = param_group.prior["1|participant_id"] + group_slope_prior = param_group.prior["theta|participant_id"] + + _check_group_prior_intercept_ddm(group_intercept_prior, prior) + _check_group_prior(group_slope_prior) + + +def test_hssm_override_default_prior(cavanagh_test): + model1 = hssm.HSSM( + model="angle", + data=cavanagh_test, + hierarchical=False, + include=[ + { + "name": "v", + "formula": "v ~ 1 + C(conf)", + } + ], + prior_settings="safe", + ) + + param_v = model1.params["v"] + assert param_v.prior["Intercept"].name == "Normal" + assert param_v.prior["Intercept"].is_truncated + + model2 = hssm.HSSM( + model="ddm", + data=cavanagh_test, + hierarchical=True, + include=[ + { + "name": "v", + "formula": "v ~ 1 + theta", + "prior": {"Intercept": {"name": "Uniform", "lower": -10, "upper": 10}}, + }, + ], + 
prior_settings="safe", + ) + param_v = model2.params["v"] + assert param_v.prior["Intercept"].name == "Uniform" + assert param_v.prior["theta"].name == "Normal" + + param_a = model2.params["a"] + assert param_a.prior["Intercept"].name == a_mu["name"] + assert "1|participant_id" in param_a.prior From 68cacd5f96380c41f28fed764ba66352d8f9a753 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:56:54 -0500 Subject: [PATCH 14/31] suppress jax warning --- src/hssm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/utils.py b/src/hssm/utils.py index 5534aa28..64f84660 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -18,7 +18,7 @@ import xarray as xr from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm from huggingface_hub import hf_hub_download -from jax.config import config +from jax import config from pymc.model_graph import ModelGraph from pytensor import function From 01ba4dcf79819ab6aa92177f669350a155e7b14c Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:58:22 -0500 Subject: [PATCH 15/31] specify float type for each test file --- tests/test_config.py | 2 ++ tests/test_distribution_utils.py | 2 ++ tests/test_hssm.py | 5 ++--- tests/test_likelihoods.py | 4 ++++ tests/test_onnx.py | 3 ++- tests/test_plotting.py | 2 ++ tests/test_prior.py | 3 +++ tests/test_sample_posterior_predictive.py | 2 ++ tests/test_simulator.py | 3 +++ tests/test_utils.py | 2 ++ 10 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 3cc77e44..b0261bde 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,8 @@ import hssm from hssm.config import Config, ModelConfig +hssm.set_floatX("float32") + def test_from_defaults(): # Case 1: Has default prior diff --git a/tests/test_distribution_utils.py b/tests/test_distribution_utils.py index 8d12970d..f1efc5b1 100644 --- a/tests/test_distribution_utils.py +++ 
b/tests/test_distribution_utils.py @@ -9,6 +9,8 @@ from hssm.distribution_utils.dist import apply_param_bounds_to_loglik, make_distribution from hssm.likelihoods.analytical import logp_ddm, DDM +hssm.set_floatX("float32") + def test_make_ssm_rv(): params = ["v", "a", "z", "t"] diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 3d890fe9..91ecaa5f 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -2,15 +2,14 @@ import bambi as bmb import numpy as np -import pandas as pd -import pytensor import pytest +import hssm from hssm import HSSM from hssm.utils import download_hf from hssm.likelihoods import DDM, logp_ddm -pytensor.config.floatX = "float32" +hssm.set_floatX("float32") param_v = { "name": "v", diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index c3c457d2..18b7f441 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -9,10 +9,14 @@ import pytest from numpy.random import rand +import hssm + # pylint: disable=C0413 from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox +hssm.set_floatX("float32") + def test_kterm(data_ddm): """This function defines a range of kterms and tests results to diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 31e464b6..5474e7ee 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -7,9 +7,10 @@ import pytensor.tensor as pt import pytest +import hssm from hssm.distribution_utils.onnx import * -pytensor.config.floatX = "float32" +hssm.set_floatX("float32") DECIMAL = 4 diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c9a100fb..93588137 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -23,6 +23,8 @@ plot_posterior_predictive, ) +hssm.set_floatX("float32") + def test__get_title(): assert _get_title(("a"), ("b")) == "a = b" diff --git a/tests/test_prior.py b/tests/test_prior.py index 42d07611..cfdcd1f8 100644 --- a/tests/test_prior.py +++ 
b/tests/test_prior.py @@ -3,8 +3,11 @@ import bambi as bmb import numpy as np +import hssm from hssm import Prior +hssm.set_floatX("float32") + def test_truncation(): hssm_prior = Prior("Uniform", lower=0.0, upper=1.0) diff --git a/tests/test_sample_posterior_predictive.py b/tests/test_sample_posterior_predictive.py index d10a8cc9..08ebdb91 100644 --- a/tests/test_sample_posterior_predictive.py +++ b/tests/test_sample_posterior_predictive.py @@ -1,5 +1,7 @@ import hssm +hssm.set_floatX("float32") + def test_sample_posterior_predictive(cav_idata, cavanagh_test): model = hssm.HSSM( diff --git a/tests/test_simulator.py b/tests/test_simulator.py index af7d9605..0eaf1949 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -2,8 +2,11 @@ import pandas as pd import pytest +import hssm from hssm.simulator import simulate_data +hssm.set_floatX("float32") + def test_simulator(): theta = [0.5, 1.5, 0.5, 0.5] diff --git a/tests/test_utils.py b/tests/test_utils.py index 5a004664..af606ed8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,6 +12,8 @@ _random_sample, ) +hssm.set_floatX("float32") + def test_get_alias_dict(): # Simulate some data: From bcda15939fb50c934e7f832bf84436ac4370c5db Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 5 Dec 2023 13:58:35 -0500 Subject: [PATCH 16/31] update hssm version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0031c008..34b3a4ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "HSSM" -version = "0.1.5" +version = "0.2.0" description = "Bayesian inference for hierarchical sequential sampling models." 
authors = [ "Alexander Fengler ", From 2567e32cee115a4be6a4c1eb74ab8cec166c5cc6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 11:23:22 -0500 Subject: [PATCH 17/31] Updated default parameter specifications --- src/hssm/hssm.py | 5 +++-- src/hssm/prior.py | 24 +++++++++++++++--------- tests/test_param.py | 35 +++++++++++++++++++---------------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 80b7bc7d..0e2ef27e 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -974,8 +974,9 @@ def _find_parent(self) -> tuple[str, Param]: def _override_defaults(self): """Override the default priors or links.""" is_ddm = ( - self.model_name == "ddm" and self.loglik_kind == "analytical" - ) # or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox") + self.model_name in ["ddm", "ddm_sdv", "ddm_full"] + and self.loglik_kind != "approx_differentiable" + ) for param in self.list_params: param_obj = self.params[param] if self.prior_settings == "safe": diff --git a/src/hssm/prior.py b/src/hssm/prior.py index eb9a028b..4f44d6ac 100644 --- a/src/hssm/prior.py +++ b/src/hssm/prior.py @@ -282,29 +282,35 @@ def get_hddm_default_prior( "Weibull": {"alpha": 1.5, "beta": 0.3}, "HalfNormal": {"sigma": 0.25}, "Beta": {"alpha": 1.0, "beta": 1.0}, - "Gamma": {"alpha": 1.0, "beta": 1.0}, + "Gamma": {"mu": 1.0, "sigma": 1.0}, } HDDM_MU: dict[Any, Any] = { "v": {"dist": "Normal", "mu": 2.0, "sigma": 3.0}, - "a": {"dist": "Gamma", "alpha": 1.5, "beta": 0.75}, - "z": {"dist": "Normal", "mu": 0.5, "sigma": 0.5}, - "t": {"dist": "Gamma", "alpha": 0.4, "beta": 0.2}, + "a": {"dist": "Gamma", "mu": 1.5, "sigma": 0.75}, + "z": {"dist": "Gamma", "mu": 10, "sigma": 10}, + "t": {"dist": "Gamma", "mu": 0.4, "sigma": 0.2}, + "sv": {"dist": "HalfNormal", "sigma": 2.0}, + "st": {"dist": "HalfNormal", "sigma": 0.3}, + "sz": {"dist": "HalfNormal", "sigma": 0.5}, } HDDM_SIGMA: dict[Any, Any] = { "v": {"dist": "HalfNormal", "sigma": 
2.0}, "a": {"dist": "HalfNormal", "sigma": 0.1}, - "z": {"dist": "HalfNormal", "sigma": 0.05}, + "z": {"dist": "Gamma", "mu": 10, "sigma": 10}, "t": {"dist": "HalfNormal", "sigma": 1.0}, - # "sv": {"dist": "HalfNormal", "sigma": 2.0}, - # "sz": {"dist": "Beta", "alpha": 1.0, "beta": 3.0}, - # "st": {"dist": "HalfNormal", "sigma": 0.3}, + "sv": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"}, + "sz": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"}, + "st": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"}, } HDDM_SETTINGS_GROUP: dict[Any, Any] = { "v": {"dist": "Normal", "mu": HDDM_MU["v"], "sigma": HDDM_SIGMA["v"]}, - "a": {"dist": "Gamma", "alpha": HDDM_MU["a"], "beta": HDDM_SIGMA["a"]}, + "a": {"dist": "Gamma", "mu": HDDM_MU["a"], "sigma": HDDM_SIGMA["a"]}, "z": {"dist": "Beta", "alpha": HDDM_MU["z"], "beta": HDDM_SIGMA["z"]}, "t": {"dist": "Normal", "mu": HDDM_MU["t"], "sigma": HDDM_SIGMA["t"]}, + "sv": {"dist": "Gamma", "mu": HDDM_MU["sv"], "sigma": HDDM_SIGMA["sv"]}, + "sz": {"dist": "Gamma", "mu": HDDM_MU["sz"], "sigma": HDDM_SIGMA["sz"]}, + "st": {"dist": "Gamma", "mu": HDDM_MU["st"], "sigma": HDDM_SIGMA["st"]}, } diff --git a/tests/test_param.py b/tests/test_param.py index 745ace5b..adb4a6d8 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -11,8 +11,6 @@ ) from hssm.defaults import default_model_config -hssm.set_floatX("float64") - def test_param_creation_non_regression(): # Test different param creation strategies @@ -453,6 +451,8 @@ def _check_group_prior(group_prior): param_and_bounds, ) def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds): + # Necessary for verifying the values of certain parameters of the priors + hssm.set_floatX("float64") # Shouldn't do anything if the param is not a regression param_non_reg = Param( name=param_name, @@ -537,7 +537,8 @@ def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds ) param_no_common_intercept.override_default_priors(cavanagh_test, {}) - 
assert "limitation" in caplog.records[0].msg + print(caplog.records) + assert "limitation" in caplog.records[-1].msg assert "Intercept" not in param_no_common_intercept.prior group_intercept_prior = param_group.prior["1|participant_id"] @@ -546,27 +547,26 @@ def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds _check_group_prior(group_intercept_prior) _check_group_prior(group_slope_prior) + # Change back after testing + hssm.set_floatX("float32") + v_mu = {"name": "Normal", "mu": 2.0, "sigma": 3.0} v_sigma = {"name": "HalfNormal", "sigma": 2.0} v_prior = {"name": "Normal", "mu": v_mu, "sigma": v_sigma} -a_mu = {"name": "Gamma", "alpha": 1.5, "beta": 0.75} +a_mu = {"name": "Gamma", "mu": 1.5, "sigma": 0.75} a_sigma = {"name": "HalfNormal", "sigma": 0.1} -a_prior = {"name": "Gamma", "alpha": a_mu, "beta": a_sigma} +a_prior = {"name": "Gamma", "mu": a_mu, "sigma": a_sigma} -z_mu = {"name": "Normal", "mu": 0.5, "sigma": 0.5} -z_sigma = {"name": "HalfNormal", "sigma": 0.05} +z_mu = {"name": "Gamma", "mu": 10.0, "sigma": 10.0} +z_sigma = {"name": "Gamma", "mu": 10.0, "sigma": 10.0} z_prior = {"name": "Beta", "alpha": z_mu, "beta": z_sigma} -t_mu = {"name": "Gamma", "alpha": 0.4, "beta": 0.2} +t_mu = {"name": "Gamma", "mu": 0.4, "sigma": 0.2} t_sigma = {"name": "HalfNormal", "sigma": 1} t_prior = {"name": "Normal", "mu": t_mu, "sigma": t_sigma} -sv = {"name": "HalfNormal", "sigma": 2.0} -st = {"name": "HalfNormal", "sigma": 0.3} -sz = {"name": "Beta", "alpha": 1.0, "beta": 1.0} - @pytest.mark.parametrize( ("param_name", "mu", "prior"), @@ -580,6 +580,8 @@ def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds def test_param_override_default_priors_ddm( cavanagh_test, caplog, param_name, mu, prior ): + # Necessary for verifying the values of certain parameters of the priors + hssm.set_floatX("float64") # Shouldn't do anything if the param is not a regression param_non_reg = Param( name=param_name, @@ -609,9 +611,7 @@ 
def test_param_override_default_priors_ddm( assert intercept_prior.name == mu1.pop("name") for key, val in mu1.items(): val1 = intercept_prior._args[key] - if isinstance(val, np.ndarray): - val1 = val1.item() - assert val1 == val + np.testing.assert_almost_equal(val1, val) assert isinstance(slope_prior, bmb.Prior) assert slope_prior.dist is None @@ -676,7 +676,7 @@ def _check_group_prior_intercept_ddm(group_prior, prior): ) param_no_common_intercept.override_default_priors_ddm(cavanagh_test, {}) - assert "limitation" in caplog.records[0].msg + assert "limitation" in caplog.records[-1].msg assert "Intercept" not in param_no_common_intercept.prior group_intercept_prior = param_group.prior["1|participant_id"] @@ -685,6 +685,9 @@ def _check_group_prior_intercept_ddm(group_prior, prior): _check_group_prior_intercept_ddm(group_intercept_prior, prior) _check_group_prior(group_slope_prior) + # Change back after testing + hssm.set_floatX("float32") + def test_hssm_override_default_prior(cavanagh_test): model1 = hssm.HSSM( From 84f1d58d3947174e37f02eda3b6095b2d148cfb1 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 11:23:51 -0500 Subject: [PATCH 18/31] suppress some warnings --- tests/test_hssm.py | 2 +- tests/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 91ecaa5f..6169b44f 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -189,7 +189,7 @@ def test_sample_prior_predictive(data_ddm_reg): ) prior_predictive_5 = model_regression_multi.sample_prior_predictive(draws=10) - data_ddm_reg["subject_id"] = np.arange(10) + data_ddm_reg.loc[:, "subject_id"] = np.arange(10) model_regression_random_effect = HSSM( data=data_ddm_reg, diff --git a/tests/test_utils.py b/tests/test_utils.py index af606ed8..6e629aa8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import pytensor import pytest from ssms.basic_simulators.simulator import simulator -from jax.config 
import config +from jax import config import hssm from hssm.utils import ( From 7a50254fa2a8339a0be60fdac28bfb85ac2badf3 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 11:28:39 -0500 Subject: [PATCH 19/31] bump ssm-simulators version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 34b3a4ff..a4cda40d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ numpy = ">=1.23.4,<1.26" onnx = "^1.12.0" jax = "^0.4.0" jaxlib = "^0.4.0" -ssm-simulators = "0.5.1" +ssm-simulators = "0.5.3" huggingface-hub = "^0.15.1" onnxruntime = "^1.15.0" bambi = "^0.12.0" From 6c935d003984254a23cebb1a4374162f14163935 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 13:02:06 -0500 Subject: [PATCH 20/31] update ssm-simulators --- mkdocs.yml | 4 ++-- src/hssm/hssm.py | 4 ++-- src/hssm/param.py | 5 ----- src/hssm/utils.py | 2 +- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index c8414a10..304f784b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -126,5 +126,5 @@ markdown_extensions: - pymdownx.superfences - attr_list - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 0e2ef27e..693b135a 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -644,11 +644,11 @@ def plot_trace( data : optional An ArviZ InferenceData object. If None, the traces stored in the model will be used. - include deterministic : optional + include_deterministic : optional Whether to include deterministic variables in the plot. Defaults to False. Note that if include deterministic is set to False and and `var_names` is provided, the `var_names` provided will be modified to also exclude the - deterministic values. 
If this is not desirable, set + deterministic values. If this is not desirable, set `include deterministic` to True. tight_layout : optional Whether to call plt.tight_layout() after plotting. Defaults to True. diff --git a/src/hssm/param.py b/src/hssm/param.py index eac3ca99..c7c62899 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -318,11 +318,6 @@ def convert(self): if self.formula is not None: # The regression case - - self.formula = ( - self.formula if "~" in self.formula else f"{self.name} ~ {self.formula}" - ) - if isinstance(self.prior, (float, bmb.Prior)): raise ValueError( "Please specify priors for each individual parameter in the " diff --git a/src/hssm/utils.py b/src/hssm/utils.py index 64f84660..f06c3fdc 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -324,7 +324,7 @@ def _process_param_in_kwargs( Raises ------ ValueError - When `prior` is not a `float`, a `dict`, or a `b`mb.Prior` object. + When `prior` is not a `float`, a `dict`, or a `bmb.Prior` object. 
""" if isinstance(prior, (int, float, bmb.Prior)): return {"name": name, "prior": prior} From d9b28273a93a4206191e09d4e94c7d3574b20d82 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 13:03:11 -0500 Subject: [PATCH 21/31] update ssm-simulators --- mkdocs.yml | 4 ++-- src/hssm/hssm.py | 4 ++-- src/hssm/param.py | 5 ----- src/hssm/utils.py | 2 +- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index c8414a10..304f784b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -126,5 +126,5 @@ markdown_extensions: - pymdownx.superfences - attr_list - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 0e2ef27e..693b135a 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -644,11 +644,11 @@ def plot_trace( data : optional An ArviZ InferenceData object. If None, the traces stored in the model will be used. - include deterministic : optional + include_deterministic : optional Whether to include deterministic variables in the plot. Defaults to False. Note that if include deterministic is set to False and and `var_names` is provided, the `var_names` provided will be modified to also exclude the - deterministic values. If this is not desirable, set + deterministic values. If this is not desirable, set `include deterministic` to True. tight_layout : optional Whether to call plt.tight_layout() after plotting. Defaults to True. 
diff --git a/src/hssm/param.py b/src/hssm/param.py index eac3ca99..c7c62899 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -318,11 +318,6 @@ def convert(self): if self.formula is not None: # The regression case - - self.formula = ( - self.formula if "~" in self.formula else f"{self.name} ~ {self.formula}" - ) - if isinstance(self.prior, (float, bmb.Prior)): raise ValueError( "Please specify priors for each individual parameter in the " diff --git a/src/hssm/utils.py b/src/hssm/utils.py index 64f84660..f06c3fdc 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -324,7 +324,7 @@ def _process_param_in_kwargs( Raises ------ ValueError - When `prior` is not a `float`, a `dict`, or a `b`mb.Prior` object. + When `prior` is not a `float`, a `dict`, or a `bmb.Prior` object. """ if isinstance(prior, (int, float, bmb.Prior)): return {"name": name, "prior": prior} From f457631b6f05395506c0b57e633665be714a18c6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 13:31:20 -0500 Subject: [PATCH 22/31] fix a test --- tests/test_param.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_param.py b/tests/test_param.py index adb4a6d8..650aacf5 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -286,7 +286,7 @@ def fake_func(x): "x1": bmb.Prior("Normal", mu=0, sigma=0.5), } - param_reg_formula1 = Param("a", formula="1 + x1", prior=priors_dict) + param_reg_formula1 = Param("a", formula="a ~ 1 + x1", prior=priors_dict) param_reg_formula2 = Param( "a", formula="a ~ 1 + x1", prior=priors_dict, link=fake_link ) From 509c0e241852248c5c6756e9533e550cd419bf98 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 7 Dec 2023 13:53:00 -0500 Subject: [PATCH 23/31] set default init to --- src/hssm/hssm.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 693b135a..5d5dbbbe 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -337,6 +337,7 @@ def sample( self, 
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"] | None = None, + init: str | None = None, **kwargs, ) -> az.InferenceData | pm.Approximation: """Perform sampling using the `fit` method via bambi.Model. @@ -350,6 +351,9 @@ def sample( sampler will automatically be chosen: when the model uses the `approx_differentiable` likelihood, and `jax` backend, "nuts_numpyro" will be used. Otherwise, "mcmc" (the default PyMC NUTS sampler) will be used. + init: optional + Initialization method to use for the sampler. If any of the NUTS samplers + is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`. kwargs Other arguments passed to bmb.Model.fit(). Please see [here] (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit) @@ -385,7 +389,7 @@ def sample( ) if "step" not in kwargs: - kwargs["step"] = pm.Slice(model=self.pymc_model) + kwargs |= {"step": pm.Slice(model=self.pymc_model)} if ( self.loglik_kind == "approx_differentiable" @@ -402,7 +406,15 @@ def sample( if self._check_extra_fields(): self._update_extra_fields() - self._inference_obj = self.model.fit(inference_method=sampler, **kwargs) + if init is None: + if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: + init = "adapt_diag" + else: + init = "auto" + + self._inference_obj = self.model.fit( + inference_method=sampler, init=init, **kwargs + ) return self.traces From 2cb9608d7dc348e371a03c0d1706039101036f1a Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 12 Dec 2023 09:48:11 -0500 Subject: [PATCH 24/31] bump ssm-simulators --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a4cda40d..8cc2f29e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ numpy = ">=1.23.4,<1.26" onnx = "^1.12.0" jax = "^0.4.0" jaxlib = "^0.4.0" -ssm-simulators = "0.5.3" +ssm-simulators = "0.6.1" huggingface-hub = "^0.15.1" onnxruntime = "^1.15.0" bambi = "^0.12.0" From 
e0bac720842f83db13ffc6e1a6b3665d55343be8 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 13 Dec 2023 13:08:22 -0500 Subject: [PATCH 25/31] added string representation for generalized logit --- src/hssm/link.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/hssm/link.py b/src/hssm/link.py index 1ad00c10..68e6164e 100644 --- a/src/hssm/link.py +++ b/src/hssm/link.py @@ -78,3 +78,10 @@ def link_(x): return np.log((x - a) / (b - x)) return link_ + + def __str__(self): + """Return a string representation of the link function.""" + if self.name == "gen_logit": + lower, upper = self.bounds + return f"Generalized logit link function with bounds ({lower}, {upper})" + return super().__str__() From 67575a0d65229b9f59566ae02057ab70f7fb3575 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 13 Dec 2023 13:09:04 -0500 Subject: [PATCH 26/31] fixed a bug where link_settings does not work in hssm --- src/hssm/hssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 5d5dbbbe..82fd6512 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -1000,7 +1000,7 @@ def _override_defaults(self): param_obj.override_default_priors( self.data, self.additional_namespace ) - elif self.link_settings == "log_logit": + if self.link_settings == "log_logit": param_obj.override_default_link() def _process_all(self): From 6ae87ad2fe07c282eee27b455d640ffecdb6cf9f Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 13 Dec 2023 13:09:35 -0500 Subject: [PATCH 27/31] added documentation for GPU support --- docs/getting_started/installation.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/docs/getting_started/installation.md b/docs/getting_started/installation.md index 2b495e41..7c5264b4 100644 --- a/docs/getting_started/installation.md +++ b/docs/getting_started/installation.md @@ -37,6 +37,31 @@ a dependency by default. 
You need to have `blackjax` installed if you want to us pip install blackjax ``` +### Sampling with JAX support for GPU + +The `nuts_numpyro` sampler uses JAX as the backend and thus can support sampling on nvidia +GPU. The only thing you need to do to take advantage of this is to install JAX with CUDA +support before installing HSSM. Here's one example: + +```bash +python -m venv .venv # Create a virtual environment +source .venv/bin/activate # Activate the virtual environment + +pip install --upgrade pip + +# We need to limit the version of JAX for now due to some breaking +# changes introduced in JAX 0.4.16. +pip install --upgrade "jax[cuda11_pip]<0.4.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install hssm +``` + +The example above shows how to install JAX with CUDA 11 support. Please refer to the +[JAX Installation](https://jax.readthedocs.io/en/latest/installation.html) page for more +details on installing JAX on different platforms with GPU or TPU support. + +Note that on Google Colab, JAX support for GPU is enabled by default if the Colab backend +has GPU enabled. You simply need only install HSSM. + ### Visualizing the model Model graphs are created with `model.graph()` through `graphviz`. 
In order to use it, From 820ff008d704bb394ec40ee597c49fe357a197fe Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 13 Dec 2023 15:40:44 -0500 Subject: [PATCH 28/31] fix bugs in param.py --- src/hssm/param.py | 10 +++++----- src/hssm/prior.py | 2 ++ tests/test_param.py | 28 ++++++++++++++++++++++------ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/hssm/param.py b/src/hssm/param.py index c7c62899..175920bd 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -120,7 +120,7 @@ def override_default_link(self): return elif lower == 0.0 and np.isposinf(upper): self.link = "log" - if not np.isneginf(lower) and not np.isposinf(upper): + elif not np.isneginf(lower) and not np.isposinf(upper): self.link = Link("gen_logit", bounds=self.bounds) else: _logger.warning( @@ -168,7 +168,7 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): if term.kind == "intercept": if has_common_intercept: override_priors[name] = get_default_prior( - "group_intercept", None + "group_intercept_with_common", bounds=None ) else: # treat the term as any other group-specific term @@ -178,7 +178,7 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): + " This will change in the future." 
) override_priors[name] = get_default_prior( - "group_specific", bounds=None + "group_intercept", bounds=None ) else: override_priors[name] = get_default_prior( @@ -230,8 +230,8 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An for name, term in dm.group.terms.items(): if term.kind == "intercept": if has_common_intercept: - override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=None + override_priors[name] = get_default_prior( + "group_intercept_with_common", bounds=None ) else: # treat the term as any other group-specific term diff --git a/src/hssm/prior.py b/src/hssm/prior.py index 4f44d6ac..4fa55d77 100644 --- a/src/hssm/prior.py +++ b/src/hssm/prior.py @@ -255,6 +255,8 @@ def get_default_prior(term_type: str, bounds: tuple[float, float] | None): prior = generate_prior("Normal", mu="Normal", sigma="Weibull") elif term_type == "group_specific": prior = generate_prior("Normal", mu="Normal", sigma="Weibull") + elif term_type == "group_intercept_with_common": + prior = generate_prior("Normal", mu=0.0, sigma="Weibull") else: raise ValueError("Unrecognized term type.") return prior diff --git a/tests/test_param.py b/tests/test_param.py index 650aacf5..306172d2 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -439,6 +439,22 @@ def _check_group_prior(group_prior): assert sigma.args["beta"] == 0.3 +def _check_group_prior_with_common(group_prior): + assert isinstance(group_prior, bmb.Prior) + assert group_prior.dist is None + assert group_prior.name == "Normal" + + mu = group_prior.args["mu"] + sigma = group_prior.args["sigma"] + + assert mu == 0.0 + + assert isinstance(group_prior, bmb.Prior) + assert sigma.name == "Weibull" + assert sigma.args["alpha"] == 1.5 + assert sigma.args["beta"] == 0.3 + + angle_params = default_model_config["angle"]["list_params"] angle_bounds = default_model_config["angle"]["likelihoods"]["approx_differentiable"][ "bounds" @@ -527,7 +543,7 @@ def 
test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds group_intercept_prior = param_group.prior["1|participant_id"] group_slope_prior = param_group.prior["theta|participant_id"] - _check_group_prior(group_intercept_prior) + _check_group_prior_with_common(group_intercept_prior) _check_group_prior(group_slope_prior) param_no_common_intercept = Param( @@ -541,8 +557,8 @@ def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds assert "limitation" in caplog.records[-1].msg assert "Intercept" not in param_no_common_intercept.prior - group_intercept_prior = param_group.prior["1|participant_id"] - group_slope_prior = param_group.prior["theta|participant_id"] + group_intercept_prior = param_no_common_intercept.prior["1|participant_id"] + group_slope_prior = param_no_common_intercept.prior["theta|participant_id"] _check_group_prior(group_intercept_prior) _check_group_prior(group_slope_prior) @@ -666,7 +682,7 @@ def _check_group_prior_intercept_ddm(group_prior, prior): for key2, val2 in val1.items(): assert hyperprior.args[key2] == val2 - _check_group_prior_intercept_ddm(group_intercept_prior, prior) + _check_group_prior_with_common(group_intercept_prior) _check_group_prior(group_slope_prior) param_no_common_intercept = Param( @@ -679,8 +695,8 @@ def _check_group_prior_intercept_ddm(group_prior, prior): assert "limitation" in caplog.records[-1].msg assert "Intercept" not in param_no_common_intercept.prior - group_intercept_prior = param_group.prior["1|participant_id"] - group_slope_prior = param_group.prior["theta|participant_id"] + group_intercept_prior = param_no_common_intercept.prior["1|participant_id"] + group_slope_prior = param_no_common_intercept.prior["theta|participant_id"] _check_group_prior_intercept_ddm(group_intercept_prior, prior) _check_group_prior(group_slope_prior) From e605503715726413a23c3ad2cf3ef3ef2d288e77 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 15 Dec 2023 17:49:53 -0500 Subject: [PATCH 29/31] 
added documentation for hierachical modeling --- .../hierarchical_modeling.ipynb | 640 ++++++++++++++++++ mkdocs.yml | 2 + 2 files changed, 642 insertions(+) create mode 100644 docs/getting_started/hierarchical_modeling.ipynb diff --git a/docs/getting_started/hierarchical_modeling.ipynb b/docs/getting_started/hierarchical_modeling.ipynb new file mode 100644 index 00000000..04c81337 --- /dev/null +++ b/docs/getting_started/hierarchical_modeling.ipynb @@ -0,0 +1,640 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0aec427c-56d5-48bb-83c6-6bdc0c445ba2", + "metadata": {}, + "source": [ + "# Hierarchical Modeling\n", + "\n", + "This tutorial demonstrates how to take advantage of HSSM's hierarchical modeling capabilities. We will cover the following:\n", + "\n", + "- How to define a mixed-effect regression\n", + "- How to define a hierarchial HSSM model\n", + "- How to apply prior and link function settings to ensure successful sampling" + ] + }, + { + "cell_type": "markdown", + "id": "36aafcf1-e703-40e3-b11d-4ab5ad74655f", + "metadata": {}, + "source": [ + "## Colab Instructions\n", + "\n", + "If you would like to run this tutorial on Google colab, please click this [link](https://github.com/lnccbrown/HSSM/blob/main/docs/tutorial_notebooks/no_execute/getting_started.ipynb). \n", + "\n", + "Once you are *in the colab*, follow the *installation instructions below* and then **restart your runtime**. \n", + "\n", + "Just **uncomment the code in the next code cell** and run it!\n", + "\n", + "**NOTE**:\n", + "\n", + "You may want to *switch your runtime* to have a GPU or TPU. To do so, go to *Runtime* > *Change runtime type* and select the desired hardware accelerator.\n", + "\n", + "Note that if you switch your runtime you have to follow the installation instructions again." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "61937b47-810d-41b6-a6b8-e461c5e5ae71", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install hssm" + ] + }, + { + "cell_type": "markdown", + "id": "650b011a-62b3-4243-9ed7-2087b2f232cd", + "metadata": {}, + "source": [ + "## Import Modules" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "11fc424e-2aff-49b0-b1a9-d54c7d7f67be", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib as plt\n", + "\n", + "import hssm\n", + "\n", + "%matplotlib inline\n", + "%config InlineBackend.figure_format='retina'" + ] + }, + { + "cell_type": "markdown", + "id": "d671ab29-710a-47ab-af79-32fac7891318", + "metadata": {}, + "source": [ + "### Setting the global float type\n", + "\n", + "**Note**: Using the analytical DDM (Drift Diffusion Model) likelihood in PyMC without setting the float type in `PyTensor` may result in warning messages during sampling, which is a known bug in PyMC v5.6.0 and earlier versions. To avoid these warnings, we provide a convenience function:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "01314abb-6ee5-4fc5-975e-002768fde007", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting PyTensor floatX type to float32.\n", + "Setting \"jax_enable_x64\" to False. If this is not intended, please set `jax` to False.\n" + ] + } + ], + "source": [ + "hssm.set_floatX(\"float32\")" + ] + }, + { + "cell_type": "markdown", + "id": "c2fcd33e-46e0-41fe-8fd6-26eedd39aec0", + "metadata": {}, + "source": [ + "## 1. Defining Regressions\n", + "\n", + "Under the hood, HSSM uses [`bambi`](https://bambinos.github.io/bambi/) for model creation. 
`bambi` takes inspiration from the [`lme4` package in R](https://www.rdocumentation.org/packages/lme4/versions/1.1-35.1/topics/lmer) and supports the definition of generalized linear mixed-effect models through\n", + "R-like formulas and concepts such as link functions. This makes it possible to create arbitrary mixed-effect regressions in HSSM, which is one advantage of HSSM over HDDM. Now let's walk through the ways to define a parameter with a regression in HSSM.\n", + "\n", + "### Specifying fixed- and random-effect terms\n", + "\n", + "Suppose that we want to define a parameter `v` that has a regression defined. There are two ways to define such a parameter - either through a dictionary\n", + "or through a `hssm.Param` object:\n", + "\n", + "```\n", + "# The following code are equivalent,\n", + "# including the definition of the formula.\n", + "\n", + "# The dictionary way:\n", + "param_v = {\n", + " \"name\": \"v\",\n", + " \"formula\": \"v ~ (1|participant_id) + x + y + x:y\",\n", + " \"link\": \"identity\",\n", + " \"prior\": {\n", + " \"Intercept\": {\"name\": \"Normal\", \"mu\": 0.0, \"sigma\": 0.25},\n", + " \"1|participant_id\": {\n", + " \"name\": \"Normal\",\n", + " \"mu\": 0.0,\n", + " \"sigma\": {\"name\": \"HalfNormal\", \"sigma\": 0.2}, # this is a hyperprior\n", + " },\n", + " \"x\": {\"name\": \"Normal\", \"mu\": 0.0, \"sigma\": 0.25},\n", + " },\n", + "}\n", + "\n", + "# The object-oriented way\n", + "param_v = hssm.Param(\n", + " \"v\",\n", + " formula=\"v ~ 1 + (1|participant_id) + x*y\",\n", + " link=\"identity\",\n", + " prior={\n", + " \"Intercept\": hssm.Prior(\"Normal\", mu=0.0, sigma=0.25),\n", + " \"1|participant_id\": hssm.Prior(\n", + " \"Normal\",\n", + " mu=0.0,\n", + " sigma=hssm.Prior(\"HalfNormal\", sigma=0.2), # this is a hyperprior\n", + " ),\n", + " \"x\": hssm.Prior(\"Normal\", mu=0.0, sigma=0.25),\n", + " },\n", + ")\n", + "```\n", + "\n", + "The formula `\"v ~ (1|participant_id) + x + y + x:y\"` defines a random-intercept 
model. Like R, unless otherwise specified, a fixed-effect intercept term is added to the formula by default. You can make this explicit by adding a `1` to the formula. Or, if your regression does not have an intercept. you can explicitly remove the intercept term by using a `0` in the place of `1`: `\"v ~ 0 + (1|participant_id) + x * y\"`.\n", + "\n", + "Other fixed effect covariates are `x`, `y`, and the interaction term `x:y`. When all three terms are present, you can use the shortcut `x * y` in place of the three terms.\n", + "\n", + "The only random effect term in this model is `1|participant_id`. It is a random-intercept term with `participant_id` indicating the grouping variable. You can add another random-effect term in a similar way: `\"v ~ (1|participant_id) + (x|participant_id) + x + y + x:y\"`, or more briefly, `\"v ~ (1 + x|participant_id) + x + y + x:y\"`.\n", + "\n", + "### Specifying priors for fixed- and random-effect terms:\n", + "\n", + "As demonstrated in the above code, you can specify priors of each term through a dictionary, with the key being the name of each term, and the corresponding value being the prior specification, etiher through a dictionary, or a `hssm.Prior` object. There are a few things to note:\n", + "\n", + "* The prior of fixed-effect intercept is specified with `\"Intercept\"`, capitalized.\n", + "* For random effects, you can specify hyperpriors for the parameters of of their priors.\n", + "\n", + "### Specifying the link functions:\n", + "\n", + "Link functions is another concept in frequentist generalized linear models, which defines a transformation between the linear combination of the covariates and the response variable. This is helpful especially when the response variable is not normally distributed, e.g. in a logistic regression. In HSSM, the link function is identity by default. 
However, since some parameters of SSMs are defined on `(0, inf)` or `(0, 1)`, link function can be helpful in ensuring the result of the regression is defined for these parameters. We will come back to this later." + ] + }, + { + "cell_type": "markdown", + "id": "3cb603f1-dc4f-44d3-b12f-e4a460d2d9f9", + "metadata": {}, + "source": [ + "## 2. Defining a hierarchical HSSM model\n", + "\n", + "In fact, HSSM does not differentiate between a hierarchical or non-hierarchical model. A hierarchical model in HSSM is simply a model with one or more parameters defined as regressions. However, HSSM does provide some useful functionalities in creating hierarchical models.\n", + "\n", + "### Clarifying the use of `hierarchical` argument during model creation\n", + "\n", + "First, HSSM has a `hierarchical` argument which is a `bool`. It serves as a convenient switch to add a random-intercept regression to any parameter that is not explicitly defined by the user, using `participant_id` as a grouping variable. If there is not a `participant_id` column in the data, setting `hierarchical` to `True` will raise an error. Setting `hierarchical` to True will also change some default behavior in HSSM. Here's an example:\n", + "\n", + "
\n", + "

Note

\n", + "

\n", + " In HSSM, the default grouping variable is now `participant_id`, which is different from `subj_idx` in HDDM.\n", + "

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cab2a960-e9e6-4043-996c-57742832de0d", + "metadata": {}, + "outputs": [], + "source": [ + "# Load a package-supplied dataset\n", + "cav_data = hssm.load_data(\"cavanagh_theta\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "28c75438-c6c4-4589-8244-dac7ffc18eb5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: analytical\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Prior: Normal(mu: 0.0, sigma: 2.0)\n", + " Explicit bounds: (-inf, inf)\n", + "a:\n", + " Prior: HalfNormal(sigma: 2.0)\n", + " Explicit bounds: (0.0, inf)\n", + "z:\n", + " Prior: Uniform(lower: 0.0, upper: 1.0)\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Prior: HalfNormal(sigma: 2.0, initval: 0.10000000149011612)\n", + " Explicit bounds: (0.0, inf)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define a basic non-hierarchical model\n", + "model_non_hierarchical = hssm.HSSM(data=cav_data)\n", + "model_non_hierarchical" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7cde1989-8d6b-437a-a972-940f0fd84904", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: analytical\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Formula: v ~ 1 + (1|participant_id)\n", + " Priors:\n", + " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n", + " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (-inf, inf)\n", + 
"a:\n", + " Formula: a ~ 1 + (1|participant_id)\n", + " Priors:\n", + " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n", + " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, inf)\n", + "z:\n", + " Formula: z ~ 1 + (1|participant_id)\n", + " Priors:\n", + " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n", + " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Formula: t ~ 1 + (1|participant_id)\n", + " Priors:\n", + " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n", + " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, inf)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now let's set `hierarchical` to True\n", + "model_hierarchical = hssm.HSSM(data=cav_data, hierarchical=True, prior_settings=\"safe\")\n", + "model_hierarchical" + ] + }, + { + "cell_type": "markdown", + "id": "e4baeb2f-c49e-4d63-a962-4c6f47f3f848", + "metadata": {}, + "source": [ + "## 3. Intelligent defaults for complex hierarchical models\n", + "\n", + "`bambi` is not designed with HSSM in mind. Therefore, in cases where priors for certain parameters are not defined, the default priors supplied by `bambi` sometimes are not optimal. The same goes for link functions. `\"identity\"` link functions tend not to work well for certain parameters that are not defined on `(inf, inf)`. Therefore, we provide some default settings that the users can experiment to ensure that sampling is successful." 
+ ] + }, + { + "cell_type": "markdown", + "id": "11b405fc-d508-468c-8fd8-12283ed03945", + "metadata": {}, + "source": [ + "### `prior_settings`\n", + "\n", + "Currently we provide a `\"safe\"` strategy that uses HSSM default priors. This is turned on by default when `hierarchical` is set to `True`. One can compare the two models below, with `safe` strategy turned on and off:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "145545e3-4712-4929-84c5-53ab3f8ae051", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: approx_differentiable\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Formula: v ~ 1 + (1|participant_id)\n", + " Priors:\n", + " v_Intercept ~ Normal(mu: 0.0, sigma: 0.25)\n", + " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (-3.0, 3.0)\n", + "a:\n", + " Formula: a ~ 1 + (1|participant_id)\n", + " Priors:\n", + " a_Intercept ~ Normal(mu: 1.399999976158142, sigma: 0.25)\n", + " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.3, 2.5)\n", + "z:\n", + " Formula: z ~ 1 + (1|participant_id)\n", + " Priors:\n", + " z_Intercept ~ Normal(mu: 0.5, sigma: 0.25)\n", + " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Formula: t ~ 1 + (1|participant_id)\n", + " Priors:\n", + " t_Intercept ~ Normal(mu: 1.0, sigma: 0.25)\n", + " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, 2.0)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + 
] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_safe = hssm.HSSM(\n", + " data=cav_data,\n", + " hierarchical=True,\n", + " prior_settings=\"safe\",\n", + " loglik_kind=\"approx_differentiable\",\n", + ")\n", + "model_safe" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8ea21e3d-e5e4-4bf3-a5a1-7aea1e9d9d00", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: approx_differentiable\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Formula: v ~ 1 + (1|participant_id)\n", + " Priors:\n", + " v_Intercept ~ Normal(mu: 0.0, sigma: 0.25)\n", + " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (-3.0, 3.0)\n", + "a:\n", + " Formula: a ~ 1 + (1|participant_id)\n", + " Priors:\n", + " a_Intercept ~ Normal(mu: 1.399999976158142, sigma: 0.25)\n", + " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.3, 2.5)\n", + "z:\n", + " Formula: z ~ 1 + (1|participant_id)\n", + " Priors:\n", + " z_Intercept ~ Normal(mu: 0.5, sigma: 0.25)\n", + " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Formula: t ~ 1 + (1|participant_id)\n", + " Priors:\n", + " t_Intercept ~ Normal(mu: 1.0, sigma: 0.25)\n", + " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (0.0, 2.0)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "source": [ + "model_safe_off = hssm.HSSM(\n", + " data=cav_data,\n", + " hierarchical=True,\n", + " prior_settings=None,\n", + " loglik_kind=\"approx_differentiable\",\n", + ")\n", + "model_safe_off" + ] + }, + { + "cell_type": "markdown", + "id": "c99b6d77-b4a2-4813-8927-433b65d646a3", + "metadata": {}, + "source": [ + "### `link_settings`\n", + "\n", + "We also provide a `link_settings` switch, which changes default link functions for parameters according to their explicit bounds. See the model below with `link_settings` set to `\"log_logit\"`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "67d0f895-8188-4e44-8946-ef80af2c4b67", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: analytical\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Formula: v ~ 1 + (1|participant_id)\n", + " Priors:\n", + " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n", + " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (-inf, inf)\n", + "a:\n", + " Formula: a ~ 1 + (1|participant_id)\n", + " Priors:\n", + " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n", + " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: log\n", + " Explicit bounds: (0.0, inf)\n", + "z:\n", + " Formula: z ~ 1 + (1|participant_id)\n", + " Priors:\n", + " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n", + " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: Generalized logit link function with bounds (0.0, 1.0)\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Formula: t ~ 1 + (1|participant_id)\n", + " Priors:\n", + " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n", + " t_1|participant_id ~ 
Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: log\n", + " Explicit bounds: (0.0, inf)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_log_logit = hssm.HSSM(\n", + " data=cav_data, hierarchical=True, prior_settings=None, link_settings=\"log_logit\"\n", + ")\n", + "model_log_logit" + ] + }, + { + "cell_type": "markdown", + "id": "bc82d284-7164-4072-9a25-67fa8cc77b17", + "metadata": {}, + "source": [ + "### Mixing strategies:\n", + "\n", + "It is possible to turn on both `prior_settings` and `link_settings`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a6099bb5-2d55-4ef8-b08b-cee2edfa4bc7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Hierarchical Sequential Sampling Model\n", + "Model: ddm\n", + "\n", + "Response variable: rt,response\n", + "Likelihood: analytical\n", + "Observations: 3988\n", + "\n", + "Parameters:\n", + "\n", + "v:\n", + " Formula: v ~ 1 + (1|participant_id)\n", + " Priors:\n", + " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n", + " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: identity\n", + " Explicit bounds: (-inf, inf)\n", + "a:\n", + " Formula: a ~ 1 + (1|participant_id)\n", + " Priors:\n", + " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n", + " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: log\n", + " Explicit bounds: (0.0, inf)\n", + "z:\n", + " Formula: z ~ 1 + (1|participant_id)\n", + " Priors:\n", + " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n", + " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: Generalized logit link function with bounds (0.0, 1.0)\n", + " Explicit bounds: (0.0, 1.0)\n", + "t:\n", + " Formula: 
t ~ 1 + (1|participant_id)\n", + " Priors:\n", + " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n", + " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n", + " Link: log\n", + " Explicit bounds: (0.0, inf)\n", + "\n", + "Lapse probability: 0.05\n", + "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_safe_loglogit = hssm.HSSM(\n", + " data=cav_data, hierarchical=True, prior_settings=\"safe\", link_settings=\"log_logit\"\n", + ")\n", + "model_safe_loglogit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "762e94e1-66be-47c6-9024-59047790953a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 304f784b..11e5f9f7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,6 +12,7 @@ nav: - Getting Started: - Installation: getting_started/installation.md - Getting started: getting_started/getting_started.ipynb + - Hierarchical modeling: getting_started/hierarchical_modeling.ipynb - API References: - hssm: api/hssm.md - hssm.plotting: api/plotting.md @@ -33,6 +34,7 @@ plugins: execute: true execute_ignore: - getting_started/getting_started.ipynb + - getting_started/hierarchical_modeling.ipynb - tutorials/main_tutorial.ipynb - tutorials/likelihoods.ipynb - .ipynb_checkpoints/*.ipynb From 0d4c50738f9f9d8714cf6485670834c1a735eb95 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 15 Dec 2023 17:50:09 
-0500 Subject: [PATCH 30/31] added changelog --- docs/changelog.md | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index d641dcc2..18621e2c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,31 @@ # Changelog +## 0.2.x + +### 0.2.0 + +This is a major version update! Many changes have taken place in this version: + +#### Breaking changes + +When `hierarchical` argument of `hssm.HSSM` is set to `True`, HSSM will look into the +`data` provided for the `participant_id` field. If it does not exist, an error will +be thrown. + +#### New features + +- Added `link_settings` and `prior_settings` arguments to `hssm.HSSM`, which allow HSSM + to use intelligent default priors and link functions for complex hierarchical models. + +- Added an `hssm.plotting` submodule with `plot_posterior_predictive()` and + `plot_quantile_probability` for creating posterior predictive plots and quantile + probability plots. + +- Added an `extra_fields` argument to `hssm.HSSM` to pass additional data to the + likelihood function computation. + +- Limited `PyMC`, `pytensor`, `numpy`, and `jax` dependency versions for compatibility. + ## 0.1.x ### 0.1.5 @@ -8,8 +34,8 @@ We fixed the errors in v0.1.4. Sorry for the convenience! If you have accidental downloaded v0.1.4, please make sure that you update hssm to the current version. - We made Cython dependencies of this package available via pypi. We have also built -wheels for (almost) all platforms so there is no need to build these Cython -dependencies. + wheels for (almost) all platforms so there is no need to build these Cython + dependencies. 
### 0.1.4 From 32b6af1e914e90555243aebde23db3f2b99cc202 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 15 Dec 2023 17:53:59 -0500 Subject: [PATCH 31/31] changed version to 0.2.0b1 --- docs/overrides/main.html | 30 +++++++++++++----------------- pyproject.toml | 2 +- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/overrides/main.html b/docs/overrides/main.html index 38eda562..81fe76d3 100644 --- a/docs/overrides/main.html +++ b/docs/overrides/main.html @@ -1,22 +1,18 @@ -{% extends "base.html" %} - -{% block announce %} +{% extends "base.html" %} {% block announce %} - - {% include ".icons/fontawesome/solid/angles-down.svg" %} - - Navigate the site here! - - - v0.1.5 is released! + + {% include ".icons/fontawesome/solid/angles-down.svg" %} + + Navigate the site here! + v0.2.0b1 is released! - - {% include ".icons/material/head-question.svg" %} - - Questions? - - Open a discussion here! - + + {% include ".icons/material/head-question.svg" %} + + Questions? + + Open a discussion here! + {% endblock %} diff --git a/pyproject.toml b/pyproject.toml index 5c85796f..c963895a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "HSSM" -version = "0.2.0" +version = "0.2.0b1" description = "Bayesian inference for hierarchical sequential sampling models." authors = [ "Alexander Fengler ",