Skip to content

Commit

Permalink
Added UserParam class for user specification of parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Aug 30, 2024
1 parent fb78fd8 commit 5ca909a
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
The module contains the classes useful for specifying and storing parameter
specifications.
"""

from .user_param import UserParam

__all__ = ["UserParam"]
90 changes: 90 additions & 0 deletions src/hssm/param_refactor/user_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""The Param class is a container for user-specified parameters of the HSSM model."""

from dataclasses import dataclass
from typing import Any, Union

import bambi as bmb
import numpy as np


@dataclass
class UserParam:
"""Represent the user-provided specifications for the main HSSM class.
Also provides convenience functions that can be used by the HSSM class to parse
arguments.
Parameters
----------
name
The name of the parameter. This can be omitted if the Param is specified as
kwargs in the HSSM class.
prior
If a formula is not specified (the non-regression case), this parameter
expects a float value if the parameter is fixed or a dictionary that can be
parsed by Bambi as a prior specification or a Bambi Prior object.
If a formula is specified (the regression case), this parameter expects a
dictionary of param:prior, where param is the name of the response variable
specified in formula, and prior is specified as above. If left unspecified,
default priors created by Bambi will be used.
formula
The regression formula if the parameter depends on other variables.
link
The link function for the regression. It is either a string that specifies
a built-in link function in Bambi, or a Bambi Link object. If a regression
is specified and link is not specified, "identity" will be used by default.
bounds
If provided, the prior will be created with boundary checks. If this
parameter is specified as a regression, boundary checks will be skipped at
this point.
"""

name: str | None = None
prior: int | float | np.ndarray | dict[str, Any] | bmb.Prior = None
formula: str | None = None
link: str | bmb.Link | None = None
bounds: tuple[int, int] | None = None

@staticmethod
def from_dict(param_dict: dict[str, Any]) -> "UserParam":
"""Create a Param object from a dictionary.
Parameters
----------
param_dict
A dictionary with the keys "name", "prior", "formula", "link", and "bounds".
Returns
-------
Param
A Param object with the specified parameters.
"""
return UserParam(**param_dict)

@staticmethod
def from_kwargs(
name: str,
# Using Union here because "UserParam" is a forward reference
param: Union[int, float, np.ndarray, dict[str, Any], bmb.Prior, "UserParam"],
) -> "UserParam":
"""Create a Param object from keyword arguments.
Parameters
----------
name
The name of the parameter.
param
The prior specification for the parameter.
Returns
-------
Param
A Param object with the specified parameters.
"""
if isinstance(param, dict):
return UserParam(name=name, **param)
elif isinstance(param, UserParam):
param.name = name
return param

return UserParam(name=name, prior=param)
78 changes: 78 additions & 0 deletions tests/param/test_user_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import bambi as bmb
from hssm.param_refactor.user_param import UserParam


def test_user_param_initialization():
param = UserParam(
name="param1", prior=0.5, formula="y ~ x", link="identity", bounds=(-1, 1)
)
assert param.name == "param1"
assert param.prior == 0.5
assert param.formula == "y ~ x"
assert param.link == "identity"
assert param.bounds == (-1, 1)


def test_user_param_from_dict():
param_dict = {
"name": "param2",
"prior": {"mu": 0, "sigma": 1},
"formula": "y ~ z",
"link": "logit",
"bounds": (0, 1),
}
param = UserParam.from_dict(param_dict)
assert param.name == "param2"
assert param.prior == {"mu": 0, "sigma": 1}
assert param.formula == "y ~ z"
assert param.link == "logit"
assert param.bounds == (0, 1)


def test_user_param_from_kwargs_with_dict():
param = UserParam.from_kwargs(
name="param3",
param={
"prior": {"mu": 0, "sigma": 1},
"formula": "y ~ w",
"link": "probit",
"bounds": (0, 1),
},
)
assert param.name == "param3"
assert param.prior == {"mu": 0, "sigma": 1}
assert param.formula == "y ~ w"
assert param.link == "probit"
assert param.bounds == (0, 1)


def test_user_param_from_kwargs_with_user_param():
existing_param = UserParam(
name="param4", prior=1.0, formula="y ~ v", link="log", bounds=(-2, 2)
)
param = UserParam.from_kwargs(name="param5", param=existing_param)
assert param.name == "param5"
assert param.prior == 1.0
assert param.formula == "y ~ v"
assert param.link == "log"
assert param.bounds == (-2, 2)


def test_user_param_from_kwargs_with_prior():
bambi_prior = bmb.Prior("Normal", mu=0, sigma=1)
param = UserParam.from_kwargs(name="param6", param=bambi_prior)
assert param.name == "param6"
assert param.prior == bambi_prior
assert param.formula is None
assert param.link is None
assert param.bounds is None


def test_user_param_from_kwargs_constant():
constant = 0.0
param = UserParam.from_kwargs(name="param6", param=constant)
assert param.name == "param6"
assert param.prior == constant
assert param.formula is None
assert param.link is None
assert param.bounds is None

0 comments on commit 5ca909a

Please sign in to comment.