Fix InferenceData #522

Merged (13 commits) on Aug 5, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
-   rev: v0.5.2
+   rev: v0.5.4
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -16,25 +16,25 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

 [tool.poetry.dependencies]
 python = ">=3.10,<3.12"
-pymc = "~5.16"
+pymc = ">=5.16.2,<5.17.0"
 arviz = "^0.18.0"
 onnx = "^1.16.0"
 ssm-simulators = "^0.7.2"
 huggingface-hub = "^0.23.0"
-bambi = "~0.14"
+bambi = ">=0.14.0,<0.15.0"
 numpyro = "^0.15.0"
 hddm-wfpt = "^0.1.4"
 seaborn = "^0.13.2"
 jax = { version = "^0.4.25", extras = ["cuda12"], optional = true }

 [tool.poetry.group.dev.dependencies]
-pytest = "^8.2.2"
+pytest = "^8.3.1"
 mypy = "^1.10.1"
 pre-commit = "^2.20.0"
 jupyterlab = "^4.2.3"
 ipykernel = "^6.29.4"
 ipywidgets = "^8.1.2"
-ruff = "^0.5.2"
+ruff = "^0.5.4"
 graphviz = "^0.20.3"
 pytest-xdist = "^3.6.1"
 onnxruntime = "^1.17.1"
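Note that the pymc pin is not just a notation change: the tilde shorthand `~5.16` admits 5.16.0, while the new explicit range raises the floor to 5.16.2. A quick, standalone way to sanity-check such specifiers (a sketch using the `packaging` library, with versions chosen for illustration; not part of the diff):

from packaging.specifiers import SpecifierSet

new_pin = SpecifierSet(">=5.16.2,<5.17.0")
print("5.16.0" in new_pin)  # False, though the old "~5.16" pin allowed it
print("5.16.2" in new_pin)  # True
print("5.17.0" in new_pin)  # False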
4 changes: 2 additions & 2 deletions src/hssm/defaults.py
@@ -393,8 +393,8 @@ def show_defaults(model: SupportedModels, loglik_kind=Optional[LoglikKind]) -> s
         output += _show_defaults_helper(model, loglik_kind)

     else:
-        for loglik_kind in model_config["likelihoods"].keys():
-            output += _show_defaults_helper(model, loglik_kind)
+        for loglik_kind_ in model_config["likelihoods"]:
+            output += _show_defaults_helper(model, loglik_kind_)
             output.append("")

     output = output[:-1]
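The loop variable gains a trailing underscore so it no longer shadows the `loglik_kind` argument (iterating the dict directly also drops the redundant `.keys()` call). A minimal, self-contained illustration of the shadowing problem the rename avoids, with made-up names (not part of the diff):

def show(kind=None):
    available = ["analytical", "approx_differentiable"]
    for kind in available:  # clobbers the caller-supplied `kind`
        pass
    return kind  # always the last list element, never the argument

print(show(kind="analytical"))  # prints: approx_differentiable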
5 changes: 5 additions & 0 deletions src/hssm/distribution_utils/onnx/onnx.py
@@ -57,6 +57,7 @@ def make_jax_logp_funcs_from_onnx(
     )

     scalars_only = all(not is_reg for is_reg in params_is_reg)
+    print("scalars only: ", scalars_only)

     def logp(*inputs) -> jnp.ndarray:
         """Compute the log-likelihood.
@@ -76,11 +77,14 @@ def logp(*inputs) -> jnp.ndarray:
             The element-wise log-likelihoods.
         """
         # Makes a matrix to feed to the LAN model
+        print("scalars only: ", scalars_only)
+        print("params only: ", params_only)
         if params_only:
             input_vector = jnp.array(inputs)
         else:
             data = inputs[0]
             dist_params = inputs[1:]
+            print([inp.shape for inp in dist_params])
             param_vector = jnp.array([inp.squeeze() for inp in dist_params])
             if param_vector.shape[-1] == 1:
                 param_vector = param_vector.squeeze(axis=-1)
@@ -89,6 +93,7 @@ def logp(*inputs) -> jnp.ndarray:
         return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()

     if params_only and scalars_only:
+        print("passing scalars only case")
         logp_vec = lambda *inputs: logp(*inputs).reshape((1,))
         return jit(logp_vec), jit(grad(logp)), logp_vec
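In the scalars-only branch, the scalar log-likelihood is reshaped to a length-1 vector before jitting so the return type matches the element-wise branches, while the gradient is taken of the unwrapped scalar function. A standalone sketch of that wrapping pattern, with a toy smooth function standing in for the interpreted ONNX graph (not part of the diff):

import jax.numpy as jnp
from jax import grad, jit

def logp(*inputs) -> jnp.ndarray:
    # Toy stand-in for interpret_onnx(...): a smooth scalar log-density.
    return -0.5 * sum(jnp.square(x) for x in inputs)

logp_vec = lambda *inputs: logp(*inputs).reshape((1,))
logp_jit, logp_grad = jit(logp_vec), jit(grad(logp))

print(logp_jit(jnp.array(1.0), jnp.array(2.0)))   # [-2.5]
print(logp_grad(jnp.array(1.0), jnp.array(2.0)))  # -1.0, gradient w.r.t. the first input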
127 changes: 62 additions & 65 deletions src/hssm/hssm.py
@@ -45,7 +45,6 @@
     _make_default_prior,
 )
 from hssm.utils import (
-    HSSMModelGraph,
     _get_alias_dict,
     _print_prior,
     _process_param_in_kwargs,
@@ -404,7 +403,8 @@ def __init__(
             self.model, self._parent_param, self.response_c, self.response_str
         )
         self.set_alias(self._aliases)
-        # _logger.info(self.pymc_model.initial_point())
+        self.model.build()
+        _logger.info(self.pymc_model.initial_point())

         if process_initvals:
             self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
@@ -431,6 +431,7 @@ def sample(
         ) = None,
         init: str | None = None,
         initvals: str | dict | None = None,
+        include_response_params: bool = False,
         **kwargs,
     ) -> az.InferenceData | pm.Approximation:
         """Perform sampling using the `fit` method via bambi.Model.
@@ -452,6 +453,10 @@
             values for parameters of the model, or a string "map" to use initialization
             at the MAP estimate. If "map" is used, the MAP estimate will be computed if
             not already attached to the base class from a prior call to `find_MAP`.
+        include_response_params: optional
+            Include parameters of the response distribution in the output. These
+            usually take more space than other parameters, as there is one per
+            observation. Defaults to False.
         kwargs
             Other arguments passed to bmb.Model.fit(). Please see [here]
             (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
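A hedged usage sketch of the new flag (not part of the diff; the model setup follows the HSSM quick-start, and the exact `simulate_data` arguments may differ across versions):

import hssm

# Simulate a small DDM dataset and fit it; include_response_params=False
# keeps the per-observation likelihood parameters out of the trace.
data = hssm.simulate_data(model="ddm", theta=[0.5, 1.5, 0.5, 0.3], size=100)
model = hssm.HSSM(data=data, model="ddm")
idata = model.sample(draws=500, include_response_params=False)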
@@ -538,35 +543,40 @@ def sample(
         # If the sampler is ultimately `numpyro`, make sure
         # the jitter argument is set to False
         if sampler == "nuts_numpyro":
-            if kwargs.get("jitter", None):
-                _logger.warning(
-                    "The jitter argument is set to True. "
-                    + "This argument is not supported "
-                    + "by the numpyro backend. "
-                    + "The jitter argument will be set to False."
-                )
-                kwargs["jitter"] = False
-        else:
-            if "jitter" in kwargs:
-                _logger.warning(
-                    "The jitter keyword argument is "
-                    + "supported only by the nuts_numpyro sampler. \n"
-                    + "The jitter argument will be ignored."
-                )
-                del kwargs["jitter"]
+            if "nuts_sampler_kwargs" in kwargs:
+                if kwargs["nuts_sampler_kwargs"].get("jitter"):
+                    _logger.warning(
+                        "The jitter argument is set to True. "
+                        + "This argument is not supported "
+                        + "by the numpyro backend. "
+                        + "The jitter argument will be set to False."
+                    )
+                    kwargs["nuts_sampler_kwargs"]["jitter"] = False
+            else:
+                kwargs["nuts_sampler_kwargs"] = {"jitter": False}

-        if "include_mean" not in kwargs:
-            # If not specified, include the mean prediction in
-            # kwargs to be passed to the model.fit() method
-            kwargs["include_mean"] = True
+        if self._inference_obj is not None:
+            _logger.warning(
+                "The model has already been sampled. Overwriting the previous "
+                + "inference object. Any previous reference to the inference object "
+                + "will still point to the old object."
+            )
+
+        if "nuts_sampler" not in kwargs:
+            if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
+                kwargs["nuts_sampler"] = (
+                    "pymc" if sampler == "mcmc" else sampler.split("_")[1]
+                )

         self._inference_obj = self.model.fit(
-            inference_method=sampler, init=init, **kwargs
+            inference_method=(
+                "mcmc"
+                if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
+                else sampler
+            ),
+            init=init,
+            include_response_params=include_response_params,
+            **kwargs,
         )

         # The parent was previously not part of deterministics --> compute it via
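The fit call above now funnels all three NUTS options through bambi's "mcmc" inference method and selects the backend via `nuts_sampler`. A standalone sketch of just that mapping logic (not part of the diff):

for sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax", "vi"]:
    inference_method = (
        "mcmc" if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"] else sampler
    )
    nuts_sampler = (
        ("pymc" if sampler == "mcmc" else sampler.split("_")[1])
        if inference_method == "mcmc"
        else None
    )
    print(f"{sampler} -> inference_method={inference_method}, nuts_sampler={nuts_sampler}")

# mcmc -> inference_method=mcmc, nuts_sampler=pymc
# nuts_numpyro -> inference_method=mcmc, nuts_sampler=numpyro
# nuts_blackjax -> inference_method=mcmc, nuts_sampler=blackjax
# vi -> inference_method=vi, nuts_sampler=None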
@@ -596,7 +606,7 @@ def sample_posterior_predictive(
         data: pd.DataFrame | None = None,
         inplace: bool = True,
         include_group_specific: bool = True,
-        kind: Literal["pps", "mean"] = "pps",
+        kind: Literal["response", "response_params"] = "response",
         draws: int | float | list[int] | np.ndarray | None = None,
         safe_mode: bool = True,
     ) -> az.InferenceData | None:
@@ -619,11 +629,12 @@
             Otherwise, predictions are made with common effects only (i.e. group-
             specific terms are set to zero), by default True.
         kind: optional
-            Indicates the type of prediction required. Can be `"mean"` or `"pps"`. The
-            first returns draws from the posterior distribution of the mean, while the
-            latter returns the draws from the posterior predictive distribution
-            (i.e. the posterior probability distribution for a new observation).
-            Defaults to `"pps"`.
+            Indicates the type of prediction required. Can be `"response_params"` or
+            `"response"`. The first returns draws from the posterior distribution of
+            the likelihood parameters, while the latter returns draws from the
+            posterior predictive distribution (i.e. the posterior probability
+            distribution for a new observation) in addition to the posterior
+            distribution. Defaults to `"response"`.
         draws: optional
             The number of samples to draw from the posterior predictive distribution
             from each chain.
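A hedged usage sketch of the renamed `kind` values, continuing the `model` object from the earlier sampling sketch (not part of the diff):

# Draws from the posterior predictive distribution (new observations):
model.sample_posterior_predictive(kind="response", safe_mode=True)

# Only the trial-wise likelihood parameters, no new observations:
model.sample_posterior_predictive(kind="response_params")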
@@ -697,8 +708,8 @@ def sample_posterior_predictive(
             # Reassign posterior to sub-sampled version
             setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))

-        if kind == "pps":
-            # If we run kind == 'pps' we actually run the observation RV
+        if kind == "response":
+            # If we run kind == 'response' we actually run the observation RV
             if safe_mode:
                 # safe mode splits the draws into chunks of 10 to avoid
                 # memory issues (TODO: Figure out the source of memory issues)
@@ -760,8 +771,8 @@ def sample_posterior_predictive(
                 return self.model.predict(
                     idata_copy, kind, data, False, include_group_specific
                 )
-        elif kind == "mean":
-            # If kind == 'mean', we don't need to run the RV directly,
+        elif kind == "response_params":
+            # If kind == 'response_params', we don't need to run the RV directly,
             # so there shouldn't really be any significant memory issues here;
             # we can simply ignore the settings, since the computational overhead
             # should be very small --> nudges the user towards good outputs.
@@ -772,6 +783,8 @@
             return self.model.predict(
                 idata, kind, data, inplace, include_group_specific
             )
+        else:
+            raise ValueError("`kind` must be either 'response' or 'response_params'.")

     def plot_posterior_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
         """Produce a posterior predictive plot.
@@ -884,8 +897,7 @@ def response_str(self) -> str:
         """Return the response variable names in string format."""
         return ",".join(self.response)

-    # NOTE: can't annotate return type because the graphviz dependency is
-    # optional
+    # NOTE: can't annotate return type because the graphviz dependency is optional
     def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"):
         """Produce a graphviz Digraph from a built HSSM model.

@@ -916,30 +928,22 @@ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png")
         -------
         graphviz.Graph
             The graph
-
-        Note
-        ----
-        The code is largely copied from
-        https://github.com/bambinos/bambi/blob/main/bambi/models.py
-        Credit for the code goes to Bambi developers.
         """
+        self.model._check_built()
+        graph = self.model.graph(formatting, name, figsize, dpi, fmt)

-        graphviz = HSSMModelGraph(
-            model=self.pymc_model, parent=self._parent_param
-        ).make_graph(formatting=formatting, response_str=self.response_str)
+        parent_param = self._parent_param
+        if parent_param.is_regression:
+            return graph

-        width, height = (None, None) if figsize is None else figsize
-
-        if name is not None:
-            graphviz_ = graphviz.copy()
-            graphviz_.graph_attr.update(size=f"{width},{height}!")
-            graphviz_.graph_attr.update(dpi=str(dpi))
-            graphviz_.render(filename=name, format=fmt, cleanup=True)
-
-            return graphviz_
+        # Modify the graph:
+        # 1. Remove all nodes and edges related to `{parent}_mean`.
+        graph.body = [
+            item for item in graph.body if f"{parent_param.name}_mean" not in item
+        ]
+        # 2. Add a new edge from the parent to the response.
+        graph.edge(parent_param.name, self.response_str)

-        return graphviz
+        return graph
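A minimal standalone sketch of the body-filtering trick used above, with illustrative node names (`v` as the parent, `rt,response` as the response; not part of the diff):

import graphviz

g = graphviz.Digraph()
g.edge("v", "v_mean")
g.edge("v_mean", "rt,response")

# Drop every DOT statement mentioning the intermediate `v_mean` node,
# then reconnect the parent directly to the response.
g.body = [item for item in g.body if "v_mean" not in item]
g.edge("v", "rt,response")
print(g.source)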

def plot_trace(
self,
@@ -1077,10 +1081,9 @@ def __repr__(self) -> str:
         for param in self.params.values():
             if param.name == "p_outlier":
                 continue
-            name = self.response_c if param.is_parent else param.name
             output.append(f"{param.name}:")

-            component = self.model.components[name]
+            component = self.model.components[param.name]

             # Regression case:
             if param.is_regression:
@@ -1548,14 +1551,8 @@ def _get_deterministic_var_names(self, idata) -> list[str]:
             if param.is_regression
         ]

-        # Handle the specific case where the parent is not explicitly in the traces
-        if ("~" + self._parent in var_names) and (
-            self._parent not in idata.posterior.data_vars
-        ):
-            var_names.remove("~" + self._parent)
-
-        if f"{self.response_str}_mean" in idata["posterior"].data_vars:
-            var_names.append(f"~{self.response_str}_mean")
+        if f"{self._parent}_mean" in idata["posterior"].data_vars:
+            var_names.append(f"~{self._parent}_mean")

         return var_names
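The returned names follow the ArviZ `var_names` convention, where a leading `~` means "exclude this variable". A self-contained sketch of that convention on a bundled example dataset (not part of the diff):

import arviz as az

idata = az.load_arviz_data("centered_eight")
# A leading "~" excludes a variable, which is how the names returned by
# _get_deterministic_var_names are consumed when plotting traces.
axes = az.plot_trace(idata, var_names=["~mu"])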
12 changes: 9 additions & 3 deletions src/hssm/param.py
@@ -144,7 +144,10 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
         """
         self._ensure_not_converted(context="prior")

-        if not self.is_regression:
+        # If there is no regression, or the parameter is the parent and has no
+        # formula attached (in which case it still gets a trial-wise
+        # deterministic), do nothing.
+        if not self.is_regression or (self.is_parent and self.formula is None):
             return

         override_priors = {}
@@ -213,7 +216,10 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
         self._ensure_not_converted(context="prior")
         assert self.name is not None

-        if not self.is_regression:
+        # If there is no regression, or the parameter is the parent and has no
+        # formula attached (in which case it still gets a trial-wise
+        # deterministic), do nothing.
+        if not self.is_regression or (self.is_parent and self.formula is None):
             return

         override_priors = {}
@@ -380,7 +386,7 @@ def is_regression(self) -> bool:
         bool
             A boolean that indicates if a regression is specified.
         """
-        return self.formula is not None
+        return self.formula is not None or self._is_parent

     @property
     def is_parent(self) -> bool:
8 changes: 4 additions & 4 deletions src/hssm/plotting/utils.py
@@ -23,7 +23,7 @@ def _xarray_to_df(
     We make the following assumptions:
     1. The inference data always has a posterior predictive group with a `rt,response`
        variable.
-    2. This variable always has four dimensions: `chain`, `draw`, `rt,response_obs`,
+    2. This variable always has four dimensions: `chain`, `draw`, `__obs__`,
        and `rt,response_dim`.

     Parameters
@@ -46,10 +46,10 @@

     # Convert the posterior samples to a dataframe
     stacked = (
-        sampled_posterior.stack({"obs": ["chain", "draw", f"{response_str}_obs"]})
+        sampled_posterior.stack({"obs": ["chain", "draw", "__obs__"]})
         .transpose()
         .to_pandas()
-        .rename_axis(index={f"{response_str}_obs": "obs_n"})
+        .rename_axis(index={"__obs__": "obs_n"})
         .sort_index(axis=0, level=["chain", "draw", "obs_n"])
     )

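A self-contained sketch of this stack-and-flatten pattern on a toy array with the dims assumed above (sizes are arbitrary; not part of the diff):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(2, 3, 4, 2),  # (chain, draw, __obs__, rt,response_dim)
    dims=["chain", "draw", "__obs__", "rt,response_dim"],
)
df = (
    da.stack({"obs": ["chain", "draw", "__obs__"]})
    .transpose()
    .to_pandas()
    .rename_axis(index={"__obs__": "obs_n"})
    .sort_index(axis=0, level=["chain", "draw", "obs_n"])
)
print(df.shape)  # (24, 2): one row per (chain, draw, obs_n) triple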
@@ -141,7 +141,7 @@ def _get_plotting_df(
             posterior.insert(0, "observed", "predicted")
         return posterior

-    if extra_dims and idata_posterior[f"{response_str}_obs"].size != data.shape[0]:
+    if extra_dims and idata_posterior["__obs__"].size != data.shape[0]:
         raise ValueError(
             "The number of observations in the data and the number of posterior "
             + "samples are not equal."