-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added UserParam class for user specification of parameters
- Loading branch information
1 parent
fb78fd8
commit 5ca909a
Showing
3 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""The Param class is a container for user-specified parameters of the HSSM model.""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Union | ||
|
||
import bambi as bmb | ||
import numpy as np | ||
|
||
|
||
@dataclass | ||
class UserParam: | ||
"""Represent the user-provided specifications for the main HSSM class. | ||
Also provides convenience functions that can be used by the HSSM class to parse | ||
arguments. | ||
Parameters | ||
---------- | ||
name | ||
The name of the parameter. This can be omitted if the Param is specified as | ||
kwargs in the HSSM class. | ||
prior | ||
If a formula is not specified (the non-regression case), this parameter | ||
expects a float value if the parameter is fixed or a dictionary that can be | ||
parsed by Bambi as a prior specification or a Bambi Prior object. | ||
If a formula is specified (the regression case), this parameter expects a | ||
dictionary of param:prior, where param is the name of the response variable | ||
specified in formula, and prior is specified as above. If left unspecified, | ||
default priors created by Bambi will be used. | ||
formula | ||
The regression formula if the parameter depends on other variables. | ||
link | ||
The link function for the regression. It is either a string that specifies | ||
a built-in link function in Bambi, or a Bambi Link object. If a regression | ||
is specified and link is not specified, "identity" will be used by default. | ||
bounds | ||
If provided, the prior will be created with boundary checks. If this | ||
parameter is specified as a regression, boundary checks will be skipped at | ||
this point. | ||
""" | ||
|
||
name: str | None = None | ||
prior: int | float | np.ndarray | dict[str, Any] | bmb.Prior = None | ||
formula: str | None = None | ||
link: str | bmb.Link | None = None | ||
bounds: tuple[int, int] | None = None | ||
|
||
@staticmethod | ||
def from_dict(param_dict: dict[str, Any]) -> "UserParam": | ||
"""Create a Param object from a dictionary. | ||
Parameters | ||
---------- | ||
param_dict | ||
A dictionary with the keys "name", "prior", "formula", "link", and "bounds". | ||
Returns | ||
------- | ||
Param | ||
A Param object with the specified parameters. | ||
""" | ||
return UserParam(**param_dict) | ||
|
||
@staticmethod | ||
def from_kwargs( | ||
name: str, | ||
# Using Union here because "UserParam" is a forward reference | ||
param: Union[int, float, np.ndarray, dict[str, Any], bmb.Prior, "UserParam"], | ||
) -> "UserParam": | ||
"""Create a Param object from keyword arguments. | ||
Parameters | ||
---------- | ||
name | ||
The name of the parameter. | ||
param | ||
The prior specification for the parameter. | ||
Returns | ||
------- | ||
Param | ||
A Param object with the specified parameters. | ||
""" | ||
if isinstance(param, dict): | ||
return UserParam(name=name, **param) | ||
elif isinstance(param, UserParam): | ||
param.name = name | ||
return param | ||
|
||
return UserParam(name=name, prior=param) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import bambi as bmb | ||
from hssm.param_refactor.user_param import UserParam | ||
|
||
|
||
def test_user_param_initialization(): | ||
param = UserParam( | ||
name="param1", prior=0.5, formula="y ~ x", link="identity", bounds=(-1, 1) | ||
) | ||
assert param.name == "param1" | ||
assert param.prior == 0.5 | ||
assert param.formula == "y ~ x" | ||
assert param.link == "identity" | ||
assert param.bounds == (-1, 1) | ||
|
||
|
||
def test_user_param_from_dict(): | ||
param_dict = { | ||
"name": "param2", | ||
"prior": {"mu": 0, "sigma": 1}, | ||
"formula": "y ~ z", | ||
"link": "logit", | ||
"bounds": (0, 1), | ||
} | ||
param = UserParam.from_dict(param_dict) | ||
assert param.name == "param2" | ||
assert param.prior == {"mu": 0, "sigma": 1} | ||
assert param.formula == "y ~ z" | ||
assert param.link == "logit" | ||
assert param.bounds == (0, 1) | ||
|
||
|
||
def test_user_param_from_kwargs_with_dict(): | ||
param = UserParam.from_kwargs( | ||
name="param3", | ||
param={ | ||
"prior": {"mu": 0, "sigma": 1}, | ||
"formula": "y ~ w", | ||
"link": "probit", | ||
"bounds": (0, 1), | ||
}, | ||
) | ||
assert param.name == "param3" | ||
assert param.prior == {"mu": 0, "sigma": 1} | ||
assert param.formula == "y ~ w" | ||
assert param.link == "probit" | ||
assert param.bounds == (0, 1) | ||
|
||
|
||
def test_user_param_from_kwargs_with_user_param(): | ||
existing_param = UserParam( | ||
name="param4", prior=1.0, formula="y ~ v", link="log", bounds=(-2, 2) | ||
) | ||
param = UserParam.from_kwargs(name="param5", param=existing_param) | ||
assert param.name == "param5" | ||
assert param.prior == 1.0 | ||
assert param.formula == "y ~ v" | ||
assert param.link == "log" | ||
assert param.bounds == (-2, 2) | ||
|
||
|
||
def test_user_param_from_kwargs_with_prior(): | ||
bambi_prior = bmb.Prior("Normal", mu=0, sigma=1) | ||
param = UserParam.from_kwargs(name="param6", param=bambi_prior) | ||
assert param.name == "param6" | ||
assert param.prior == bambi_prior | ||
assert param.formula is None | ||
assert param.link is None | ||
assert param.bounds is None | ||
|
||
|
||
def test_user_param_from_kwargs_constant(): | ||
constant = 0.0 | ||
param = UserParam.from_kwargs(name="param6", param=constant) | ||
assert param.name == "param6" | ||
assert param.prior == constant | ||
assert param.formula is None | ||
assert param.link is None | ||
assert param.bounds is None |