Skip to content

Commit

Permalink
Merge pull request #270 from lnccbrown/252-support-for-arbitrary-addi…
Browse files Browse the repository at this point in the history
…tional-fields

Support for arbitrary additional fields
  • Loading branch information
digicosmos86 authored Sep 1, 2023
2 parents 9ee29a3 + 85979dd commit 5a2724f
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Config:
loglik: LogLik | None = None
backend: Literal["jax", "pytensor"] | None = None
rv: RandomVariable | None = None
extra_fields: list[str] | None = None
# Fields with dictionaries are automatically deepcopied
default_priors: dict[str, ParamSpec] = field(default_factory=dict)
bounds: dict[str, tuple[float, float]] = field(default_factory=dict)
Expand Down Expand Up @@ -162,3 +163,4 @@ class ModelConfig:
bounds: dict[str, tuple[float, float]] = field(default_factory=dict)
backend: Literal["jax", "pytensor"] | None = None
rv: RandomVariable | None = None
extra_fields: list[str] | None = None
14 changes: 14 additions & 0 deletions src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class LoglikConfig(TypedDict):
backend: Optional[Literal["jax", "pytensor"]]
default_priors: dict[str, ParamSpec]
bounds: dict[str, tuple[float, float]]
extra_fields: Optional[list[str]]


LoglikConfigs = dict[LoglikKind, LoglikConfig]
Expand Down Expand Up @@ -73,6 +74,7 @@ class DefaultConfig(TypedDict):
"initval": 0.1,
},
},
"extra_fields": None,
},
"approx_differentiable": {
"loglik": "ddm.onnx",
Expand All @@ -84,6 +86,7 @@ class DefaultConfig(TypedDict):
"z": (0.0, 1.0),
"t": (0.0, 2.0),
},
"extra_fields": None,
},
"blackbox": {
"loglik": logp_ddm_bbox,
Expand All @@ -96,6 +99,7 @@ class DefaultConfig(TypedDict):
"initval": 0.1,
},
},
"extra_fields": None,
},
},
},
Expand All @@ -114,6 +118,7 @@ class DefaultConfig(TypedDict):
"initval": 0.1,
},
},
"extra_fields": None,
},
"approx_differentiable": {
"loglik": "ddm_sdv.onnx",
Expand All @@ -126,6 +131,7 @@ class DefaultConfig(TypedDict):
"t": (0.0, 2.0),
"sv": (0.0, 1.0),
},
"extra_fields": None,
},
"blackbox": {
"loglik": logp_ddm_sdv_bbox,
Expand All @@ -138,6 +144,7 @@ class DefaultConfig(TypedDict):
"initval": 0.1,
},
},
"extra_fields": None,
},
},
},
Expand All @@ -156,6 +163,7 @@ class DefaultConfig(TypedDict):
"initval": 0.1,
},
},
"extra_fields": None,
}
},
},
Expand All @@ -174,6 +182,7 @@ class DefaultConfig(TypedDict):
"t": (0.001, 2.0),
"theta": (-0.1, 1.3),
},
"extra_fields": None,
},
},
},
Expand All @@ -192,6 +201,7 @@ class DefaultConfig(TypedDict):
"alpha": (1.0, 2.0),
"t": (1e-3, 2.0),
},
"extra_fields": None,
},
},
},
Expand All @@ -210,6 +220,7 @@ class DefaultConfig(TypedDict):
"g": (-1.0, 1.0),
"t": (1e-3, 2.0),
},
"extra_fields": None,
},
},
},
Expand All @@ -229,6 +240,7 @@ class DefaultConfig(TypedDict):
"alpha": (0.31, 4.99),
"beta": (0.31, 6.99),
},
"extra_fields": None,
},
},
},
Expand All @@ -250,6 +262,7 @@ class DefaultConfig(TypedDict):
"ndt": (0.0, 2.0),
"theta": (-0.1, 1.45),
},
"extra_fields": None,
},
},
},
Expand All @@ -268,6 +281,7 @@ class DefaultConfig(TypedDict):
"a": (0.3, 2.5),
"t": (0.0, 2.0),
},
"extra_fields": None,
},
},
},
Expand Down
56 changes: 48 additions & 8 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import logging
from os import PathLike
from typing import Any, Callable, Type
from typing import Any, Callable, Iterable, Type

import bambi as bmb
import numpy as np
Expand Down Expand Up @@ -199,7 +199,14 @@ def rng_fn(
if not np.isscalar(size):
size = np.squeeze(size)

arg_arrays = [np.asarray(arg) for arg in args]
num_params = len(cls._list_params)

# TODO: We need to figure out what to do with extra_fields when
# doing posterior predictive sampling. Right now nothing is done.
if num_params < len(args):
arg_arrays = [np.asarray(arg) for arg in args[:num_params]]
else:
arg_arrays = [np.asarray(arg) for arg in args]

p_outlier = None

Expand Down Expand Up @@ -299,6 +306,7 @@ def make_distribution(
list_params: list[str],
bounds: dict | None = None,
lapse: bmb.Prior | None = None,
extra_fields: list[np.ndarray] | None = None,
) -> Type[pm.Distribution]:
"""Make a `pymc.Distribution`.
Expand All @@ -325,6 +333,9 @@ def make_distribution(
Example: {"parameter": (lower_boundary, upper_boundary)}.
lapse : optional
A bmb.Prior object representing the lapse distribution.
extra_fields : optional
An optional list of arrays that are stored in the class created and will be
used in likelihood calculation. Defaults to None.
Returns
-------
Expand All @@ -337,13 +348,13 @@ def make_distribution(
if list_params[-1] != "p_outlier":
list_params.append("p_outlier")

data = pt.dvector()
data_vector = pt.dvector()
lapse_logp = pm.logp(
get_distribution_from_prior(lapse).dist(**lapse.args),
data,
data_vector,
)
lapse_func = pytensor.function(
[data],
[data_vector],
lapse_logp,
)

Expand All @@ -356,29 +367,39 @@ class SSMDistribution(pm.Distribution):
# NOTE: rv_op is an INSTANCE of RandomVariable
rv_op = random_variable()
params = list_params
_extra_fields = extra_fields

@classmethod
def dist(cls, **kwargs): # pylint: disable=arguments-renamed
dist_params = [
pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params
]
if cls._extra_fields:
dist_params += [pm.floatX(field) for field in cls._extra_fields]
other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params}
return super().dist(dist_params, **other_kwargs)

def logp(data, *dist_params): # pylint: disable=E0213
num_params = len(list_params)
extra_fields: Iterable[np.ndarray] = []

if num_params < len(dist_params):
extra_fields = dist_params[num_params:]
dist_params = dist_params[:num_params]

if list_params[-1] == "p_outlier":
p_outlier = dist_params[-1]
dist_params = dist_params[:-1]
lapse_logp = lapse_func(data[:, 0].eval())

logp = loglik(data, *dist_params)
logp = loglik(data, *dist_params, *extra_fields)
logp = pt.log(
(1.0 - p_outlier) * pt.exp(logp)
+ p_outlier * pt.exp(lapse_logp)
+ 1e-29
)
else:
logp = loglik(data, *dist_params)
logp = loglik(data, *dist_params, *extra_fields)

if bounds is None:
return logp
Expand All @@ -398,6 +419,7 @@ def make_distribution_from_onnx(
bounds: dict | None = None,
params_is_reg: list[bool] | None = None,
lapse: bmb.Prior | None = None,
extra_fields: list[np.ndarray] | None = None,
) -> Type[pm.Distribution]:
"""Make a PyMC distribution from an ONNX model.
Expand Down Expand Up @@ -429,6 +451,9 @@ def make_distribution_from_onnx(
corresponding position in `list_params` is a regression.
lapse : optional
A bmb.Prior object representing the lapse distribution.
extra_fields : optional
An optional list of arrays that are stored in the class created and will be
used in likelihood calculation. Defaults to None.
Returns
-------
Expand All @@ -446,21 +471,30 @@ def make_distribution_from_onnx(
list_params,
bounds=bounds,
lapse=lapse,
extra_fields=extra_fields,
)
if backend == "jax":
if params_is_reg is None:
params_is_reg = [False for param in list_params if param != "p_outlier"]

# Extra fields are passed to the likelihood functions as vectors
# They do not need to be broadcast, so param_is_reg is padded with True
if extra_fields:
params_is_reg += [True for _ in extra_fields]

logp, logp_grad, logp_nojit = make_jax_logp_funcs_from_onnx(
onnx_model,
params_is_reg,
)
lan_logp_jax = make_jax_logp_ops(logp, logp_grad, logp_nojit)

return make_distribution(
rv,
lan_logp_jax,
list_params,
bounds=bounds,
lapse=lapse,
extra_fields=extra_fields,
)

raise ValueError("Currently only 'pytensor' and 'jax' backends are supported.")
Expand Down Expand Up @@ -581,6 +615,7 @@ def make_distribution_from_blackbox(
loglik: Callable,
list_params: list[str],
bounds: dict | None = None,
extra_fields: list[np.ndarray] | None = None,
) -> Type[pm.Distribution]:
"""Make a `pymc.Distribution`.
Expand All @@ -604,6 +639,9 @@ def make_distribution_from_blackbox(
bounds : optional
A dictionary with parameters as keys (a string) and its boundaries as values.
Example: {"parameter": (lower_boundary, upper_boundary)}.
extra_fields : optional
An optional list of arrays that are stored in the class created and will be
used in likelihood calculation. Defaults to None.
Returns
-------
Expand All @@ -612,4 +650,6 @@ def make_distribution_from_blackbox(
"""
blackbox_op = make_blackbox_op(loglik)

return make_distribution(rv, blackbox_op, list_params, bounds)
return make_distribution(
rv, blackbox_op, list_params, bounds, extra_fields=extra_fields
)
Loading

0 comments on commit 5a2724f

Please sign in to comment.