Merge pull request #234 from lnccbrown/170-alignment-of-model_config-entries

Updated the way model configs are handled
digicosmos86 authored Jul 26, 2023
2 parents 2c13fac + 5ca8dd1 commit 01f0cd9
Showing 10 changed files with 371 additions and 218 deletions.
8 changes: 2 additions & 6 deletions src/hssm/__init__.py
@@ -3,6 +3,7 @@
 import logging
 import sys
 
+from .config import show_defaults
 from .datasets import load_data
 from .hssm import HSSM
 from .simulator import simulate_data
@@ -13,9 +14,4 @@
 handler = logging.StreamHandler(stream=sys.stdout)
 _logger.addHandler(handler)
 
-__all__ = [
-    "HSSM",
-    "load_data",
-    "simulate_data",
-    "set_floatX",
-]
+__all__ = ["HSSM", "load_data", "simulate_data", "set_floatX", "show_defaults"]
119 changes: 103 additions & 16 deletions src/hssm/config.py
@@ -1,5 +1,7 @@
 """Provide default configurations for models in the HSSM class."""
-from typing import Any, Literal
+from typing import Any, Literal, Optional
 
+import bambi as bmb
+
 from .likelihoods.analytical import (
     ddm_bounds,
@@ -8,6 +10,7 @@
     logp_ddm,
     logp_ddm_sdv,
 )
+from .param import _make_default_prior
 
 SupportedModels = Literal[
     "ddm",
@@ -39,15 +42,11 @@
             "loglik": logp_ddm,
             "bounds": ddm_bounds,
             "default_priors": {
-                "v": {"name": "Uniform", "lower": -10.0, "upper": 10.0},
-                "a": {"name": "HalfNormal", "sigma": 2.0},
-                "z": None,
                 "t": {"name": "Uniform", "lower": 0.0, "upper": 2.0, "initval": 0.1},
             },
         },
         "approx_differentiable": {
             "loglik": "ddm.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-3.0, 3.0),
                 "a": (0.3, 2.5),
@@ -61,16 +60,11 @@
             "loglik": logp_ddm_sdv,
             "bounds": ddm_bounds,
             "default_priors": {
-                "v": {"name": "Uniform", "lower": -10.0, "upper": 10.0},
-                "a": {"name": "HalfNormal", "sigma": 2.0},
-                "z": None,
                 "t": {"name": "Uniform", "lower": 0.0, "upper": 2.0, "initval": 0.1},
-                "sv": {"name": "HalfNormal", "sigma": 1.0},
             },
         },
         "approx_differentiable": {
             "loglik": "ddm_sdv.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-3.0, 3.0),
                 "a": (0.3, 2.5),
@@ -83,7 +77,6 @@
     "angle": {
         "approx_differentiable": {
             "loglik": "angle.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-3.0, 3.0),
                 "a": (0.3, 3.0),
@@ -96,7 +89,6 @@
     "levy": {
         "approx_differentiable": {
             "loglik": "levy.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-3.0, 3.0),
                 "a": (0.3, 3.0),
@@ -109,7 +101,6 @@
     "ornstein": {
         "approx_differentiable": {
             "loglik": "ornstein.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-2.0, 2.0),
                 "a": (0.3, 3.0),
@@ -122,7 +113,6 @@
     "weibull": {
         "approx_differentiable": {
             "loglik": "weibull.onnx",
-            "backend": "jax",
             "bounds": {
                 "v": (-2.5, 2.5),
                 "a": (0.3, 2.5),
@@ -136,7 +126,6 @@
     "race_no_bias_angle_4": {
         "approx_differentiable": {
             "loglik": "race_no_bias_angle_4.onnx",
-            "backend": "jax",
             "bounds": {
                 "v0": (0.0, 2.5),
                 "v1": (0.0, 2.5),
@@ -152,7 +141,6 @@
     "ddm_seq2_no_bias": {
         "approx_differentiable": {
            "loglik": "ddm_seq2_no_bias.onnx",
-            "backend": "jax",
             "bounds": {
                 "vh": (-4.0, 4.0),
                 "vl1": (-4.0, 4.0),
@@ -174,3 +162,102 @@
     "race_no_bias_angle_4": ["v0", "v1", "v2", "v3", "a", "z", "ndt", "theta"],
     "ddm_seq2_no_bias": ["vh", "vl1", "vl2", "a", "t"],
 }
+
+
+def show_defaults(model: SupportedModels, loglik_kind: Optional[LoglikKind] = None) -> str:
+    """Show the defaults for supported models.
+
+    Parameters
+    ----------
+    model
+        One of the supported model strings.
+    loglik_kind : optional
+        The kind of likelihood function, by default None, in which case the defaults
+        for all likelihoods will be shown.
+
+    Returns
+    -------
+    str
+        A nicely organized printout of the defaults of the provided model.
+    """
+    if model not in default_model_config:
+        raise ValueError(f"{model} does not currently have defaults in HSSM.")
+
+    model_config = default_model_config[model]
+
+    output = []
+    output.append("Default model config:")
+    output.append(f"Model: {model}")
+
+    # Will implement later
+    # if "description" in model_config:
+    #     output.append("Description:")
+    #     output.append(f" {model_config['description']}")
+
+    output.append(f"Default parameters: {default_params[model]}")
+    output.append("")
+
+    if loglik_kind is not None:
+        if loglik_kind not in model_config:
+            raise ValueError(
+                f"{model} does not currently have defaults for `{loglik_kind}` "
+                + "log-likelihoods in HSSM."
+            )
+
+        output += _show_defaults_helper(model, loglik_kind)
+
+    else:
+        for loglik_kind in model_config.keys():
+            output += _show_defaults_helper(model, loglik_kind)
+            output.append("")
+
+        output = output[:-1]
+
+    return "\r\n".join(output)
+
+
+def _show_defaults_helper(model: SupportedModels, loglik_kind: LoglikKind) -> list[str]:
+    """Show the defaults for supported models.
+
+    Parameters
+    ----------
+    model
+        One of the supported model strings.
+    loglik_kind
+        The kind of likelihood function.
+
+    Returns
+    -------
+    list[str]
+        A list of nicely organized printouts of the defaults of the provided model.
+    """
+    output = []
+    params = default_params[model]
+    model_defaults = default_model_config[model][loglik_kind]
+
+    output.append(f"Log-likelihood kind: {loglik_kind}")
+    output.append(f"Log-likelihood: {model_defaults['loglik']}")
+    if loglik_kind == "approx_differentiable":
+        output.append("Default backend: jax")
+    output.append("Default priors:")
+
+    default_priors = model_defaults.get("default_priors", {})
+    default_bounds = model_defaults.get("bounds", {})
+
+    for param in params:
+        prior = default_priors.get(param, None)
+        if prior is None:
+            bounds = default_bounds.get(param, None)
+            prior = _make_default_prior(bounds) if bounds is not None else None
+        else:
+            if isinstance(prior, dict):
+                prior = bmb.Prior(**prior)
+
+        output.append(f" {param} ~ {prior}")
+
+    output.append("Default bounds:")
+
+    for param in params:
+        output.append(f" {param}: {default_bounds.get(param, None)}")
+
+    return output
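A minimal usage sketch of the new `show_defaults` helper (the calls mirror the test added in tests/test_config.py below; the printout format follows the helper code above):

    import hssm

    # Defaults for every available likelihood kind of the DDM:
    print(hssm.show_defaults("ddm", None))
    # Defaults for the analytical likelihood only:
    print(hssm.show_defaults("ddm", "analytical"))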
2 changes: 1 addition & 1 deletion src/hssm/distribution_utils/dist.py
@@ -401,7 +401,7 @@ def make_distribution_from_onnx(
     rv: str | Type[RandomVariable],
     list_params: list[str],
     onnx_model: str | PathLike | onnx.ModelProto,
-    backend: str = "pytensor",
+    backend: str = "jax",
     bounds: dict | None = None,
     params_is_reg: list[bool] | None = None,
     lapse: bmb.Prior | None = None,
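This default flip from "pytensor" to "jax" pairs with the removal of the per-model "backend": "jax" entries in config.py above: ONNX-based models that do not specify a backend now get JAX. A hedged sketch of a call relying on the new default (the parameter list shown is an assumption for illustration, not taken from this diff):

    # Omitting `backend` now selects the JAX backend by default.
    dist = make_distribution_from_onnx(
        rv="angle",
        list_params=["v", "a", "z", "t", "theta"],  # assumed parameter list
        onnx_model="angle.onnx",
    )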
6 changes: 4 additions & 2 deletions src/hssm/hssm.py
@@ -511,7 +511,7 @@ def sample(
         if sampler is None:
             if (
                 self.loglik_kind == "approx_differentiable"
-                and self.model_config["backend"] == "jax"
+                and self.model_config.get("backend") == "jax"
             ):
                 sampler = "nuts_numpyro"
             else:
@@ -535,7 +535,7 @@
 
         if (
             self.loglik_kind == "approx_differentiable"
-            and self.model_config["backend"] == "jax"
+            and self.model_config.get("backend") == "jax"
             and sampler == "mcmc"
             and kwargs.get("cores", None) != 1
         ):
@@ -857,6 +857,8 @@ def _create_param(param: str | dict, model_config: dict, is_parent: bool) -> Param:
             and "formula" not in model_config
         ):
             prior = model_config["default_priors"][name]
+        else:
+            prior = None
     else:
         prior = param["prior"]
     return Param(
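The switch from `model_config["backend"]` to `model_config.get("backend")` matters because the `"backend"` keys were just dropped from the default configs: a plain subscript would raise `KeyError` for those models, while `.get` returns `None`, so the comparison with `"jax"` simply evaluates to `False`. A minimal illustration of the dict semantics:

    config = {"loglik": "angle.onnx"}  # no "backend" key after this change
    config.get("backend") == "jax"     # False; no exception raised
    # config["backend"]                # would raise KeyError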
12 changes: 7 additions & 5 deletions src/hssm/likelihoods/analytical.py
@@ -13,6 +13,7 @@
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
+from numpy import inf
 from pymc.distributions.dist_math import check_parameters
 
 from ..distribution_utils.dist import make_distribution
@@ -357,12 +358,13 @@ def logp_ddm_sdv(
     return checked_logp
 
 
-ddm_bounds = {"z": (0.0, 1.0)}
-ddm_sdv_bounds = ddm_bounds | {
-    "v": (-3.0, 3.0),
-    "a": (0.3, 2.5),
-    "t": (0.0, 2.0),
+ddm_bounds = {
+    "v": (-inf, inf),
+    "a": (0.0, inf),
+    "z": (0.0, 1.0),
+    "t": (0.0, inf),
 }
+ddm_sdv_bounds = ddm_bounds | {"sv": (0.0, inf)}
 
 ddm_params = ["v", "a", "z", "t"]
 ddm_sdv_params = ddm_params + ["sv"]
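The analytical bounds now cover the full mathematical support of each DDM parameter (with `z` still confined to the unit interval) rather than the narrower sampling ranges, which remain in the `approx_differentiable` configs above. A sketch of what the two assignments evaluate to, with `|` (PEP 584) merging the base dict with the extra `sv` entry:

    from numpy import inf

    ddm_bounds = {"v": (-inf, inf), "a": (0.0, inf), "z": (0.0, 1.0), "t": (0.0, inf)}
    ddm_sdv_bounds = ddm_bounds | {"sv": (0.0, inf)}
    assert ddm_sdv_bounds["v"] == (-inf, inf)  # inherited from ddm_bounds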
50 changes: 44 additions & 6 deletions src/hssm/param.py
@@ -5,6 +5,7 @@
 from typing import Any, Callable, Union, cast
 
 import bambi as bmb
+import numpy as np
 import pymc as pm
 from bambi.backend.utils import get_distribution
 
@@ -69,6 +70,16 @@ def __init__(
         self.bounds = tuple(float(x) for x in bounds) if bounds is not None else None
         self._is_truncated = False
 
+        if self.bounds is not None:
+            self.bounds = cast(BoundsSpec, self.bounds)
+            if any(not np.isscalar(bound) for bound in self.bounds):
+                raise ValueError(f"The bounds of {self.name} should both be scalar.")
+            lower, upper = self.bounds
+            assert lower < upper, (
+                f"The lower bound of {self.name} should be less than "
+                + "its upper bound."
+            )
+
         if isinstance(prior, int):
             prior = float(prior)
 
@@ -93,16 +104,16 @@
             # The non-regression case
 
             if prior is None:
-                if bounds is None:
+                if self.bounds is None:
                     raise ValueError(
                         f"Please specify the prior or bounds for {self.name}."
                     )
-                self.prior = bmb.Prior(name="Uniform", lower=bounds[0], upper=bounds[1])
+                self.prior = _make_default_prior(self.bounds)
             else:
                 # Explicitly cast the type of prior, no runtime performance penalty
                 prior = cast(ParamSpec, prior)
 
-                if bounds is None:
+                if self.bounds is None:
                     if isinstance(prior, (float, bmb.Prior)):
                         self.prior = prior
                     else:
@@ -111,7 +122,7 @@
                 if isinstance(prior, float):
                     self.prior = prior
                 else:
-                    self.prior = make_bounded_prior(prior, bounds)
+                    self.prior = make_bounded_prior(prior, self.bounds)
             # self._prior is internally used for informative output
             # Not used in inference
             self._prior = (
@@ -412,6 +423,8 @@ def make_bounded_prior(prior: ParamSpec, bounds: BoundsSpec) -> float | bmb.Prior:
         A float if `prior` is a float, otherwise a bmb.Prior object.
     """
     lower, upper = bounds
+    if np.isinf(lower) and np.isinf(upper):
+        return prior
 
     if isinstance(prior, float):
         if not lower <= prior <= upper:
@@ -468,9 +481,34 @@ def TruncatedDist(name):
         return pm.Truncated(
             name=name,
             dist=dist,
-            lower=lower_bound,
-            upper=upper_bound,
+            lower=lower_bound if np.isfinite(lower_bound) else None,
+            upper=upper_bound if np.isfinite(upper_bound) else None,
             initval=initval,
         )
 
     return TruncatedDist
+
+
+def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior:
+    """Make a default prior from bounds.
+
+    Parameters
+    ----------
+    bounds
+        The (lower, upper) bounds for the default prior.
+
+    Returns
+    -------
+        A bmb.Prior object representing the default prior for the provided bounds.
+    """
+    lower, upper = bounds
+    if np.isinf(lower) and np.isinf(upper):
+        return bmb.Prior("Normal", mu=0.0, sigma=2.0)
+    elif np.isinf(lower) and not np.isinf(upper):
+        return bmb.Prior("TruncatedNormal", mu=upper, upper=upper, sigma=2.0)
+    elif not np.isinf(lower) and np.isinf(upper):
+        if lower == 0:
+            return bmb.Prior("HalfNormal", sigma=2.0)
+        return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0)
+    else:
+        return bmb.Prior(name="Uniform", lower=lower, upper=upper)
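Taken together with the new infinite `ddm_bounds`, this helper supplies a fallback prior for any parameter that has bounds but no explicit default. A sketch of the mapping, following the branches above:

    from math import inf

    from hssm.param import _make_default_prior

    _make_default_prior((-inf, inf))  # Normal(mu=0.0, sigma=2.0)
    _make_default_prior((0.0, inf))   # HalfNormal(sigma=2.0); lower == 0 is special-cased
    _make_default_prior((0.3, inf))   # TruncatedNormal(mu=0.3, lower=0.3, sigma=2.0)
    _make_default_prior((0.3, 2.5))   # Uniform(lower=0.3, upper=2.5)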
6 changes: 6 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,6 @@
+from hssm import show_defaults
+
+
+def test_show_defaults():
+    print(show_defaults("ddm", None))
+    print(show_defaults("ddm", "analytical"))