Merge pull request #452 from lnccbrown/fix-ensure-positive-ndt
Fix ensure positive ndt
digicosmos86 authored Jun 7, 2024
2 parents a135502 + 0c98c02 commit 3f7c518
Showing 2 changed files with 94 additions and 3 deletions.
src/hssm/distribution_utils/dist.py (10 changes: 7 additions & 3 deletions)

@@ -452,22 +452,26 @@ def logp(data, *dist_params): # pylint: disable=E0213
             lapse_logp = lapse_func(data[:, 0].eval())
             # AF-TODO potentially apply clipping here
             logp = loglik(data, *dist_params, *extra_fields)
+            # Ensure that non-decision time is always smaller than rt.
+            # Assuming that the non-decision time parameter is always named "t".
+            logp = ensure_positive_ndt(data, logp, list_params, dist_params)
             logp = pt.log(
                 (1.0 - p_outlier) * pt.exp(logp)
                 + p_outlier * pt.exp(lapse_logp)
                 + 1e-29
             )
         else:
             logp = loglik(data, *dist_params, *extra_fields)
+            # Ensure that non-decision time is always smaller than rt.
+            # Assuming that the non-decision time parameter is always named "t".
+            logp = ensure_positive_ndt(data, logp, list_params, dist_params)

         if bounds is not None:
             logp = apply_param_bounds_to_loglik(
                 logp, list_params, *dist_params, bounds=bounds
             )

-        # Ensure that non-decision time is always smaller than rt.
-        # Assuming that the non-decision time parameter is always named "t".
-        return ensure_positive_ndt(data, logp, list_params, dist_params)
+        return logp

     return SSMDistribution
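For reference, the pt.log(...) expression in the p_outlier branch above is a standard lapse mixture. Writing $\ell$ for the model log-likelihood (already floored by the non-decision-time guard) and $\ell_{\text{lapse}}$ for the lapse distribution's log-likelihood, it computes

$$\log p_{\text{mix}} = \log\!\left[(1 - p_{\text{outlier}})\, e^{\ell} + p_{\text{outlier}}\, e^{\ell_{\text{lapse}}} + 10^{-29}\right].$$

The $10^{-29}$ constant keeps the argument of the logarithm strictly positive. Applying the guard before this step, rather than after it as in the old code, means a floored trial (for which the new tests below expect $\ell = -66.1$, i.e. $e^{\ell} \approx 2 \times 10^{-29}$) can still be explained by the lapse distribution instead of having the finished mixture value overwritten with $-66.1$.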

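The helper ensure_positive_ndt itself is not part of this diff. Below is a minimal sketch of its presumed behavior, inferred from its call sites above and from the new tests below (which expect a floor of -66.1 wherever rt - t <= 1e-15); this is not the actual HSSM implementation, which lives elsewhere in hssm.distribution_utils:

import pytensor.tensor as pt

def ensure_positive_ndt(data, logp, list_params, dist_params):
    """Sketch: floor logp on trials where rt does not exceed non-decision time.

    The condition and the -66.1 floor are inferred from the tests below;
    the parameter lookup is an assumption, not the real code.
    """
    rt = data[:, 0]
    # The diff's comments state the assumption that the NDT parameter is "t".
    t = dist_params[list_params.index("t")]
    return pt.where(rt - t <= 1e-15, -66.1, logp)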
tests/test_likelihoods.py (87 changes: 87 additions & 0 deletions)

@@ -5,8 +5,12 @@
"""

import math
from pathlib import Path
from itertools import product

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
from numpy.random import rand

@@ -15,6 +19,7 @@
 # pylint: disable=C0413
 from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv
 from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
+from hssm.distribution_utils import make_likelihood_callable

 hssm.set_floatX("float32")

@@ -121,3 +126,85 @@ def test_bbox(data_ddm):
         logp_ddm_sdv_bbox(data, *true_values_sdv),
         decimal=4,
     )
+
+
+cav_data = hssm.load_data("cavanagh_theta")
+cav_data_numpy = cav_data[["rt", "response"]].values
+param_matrix = product(
+    (0.0, 0.01, 0.05, 0.5), ("analytical", "approx_differentiable", "blackbox")
+)
+
+
+@pytest.mark.parametrize("p_outlier, loglik_kind", param_matrix)
+def test_lapse_distribution_cav(p_outlier, loglik_kind):
+    true_values = (0.5, 1.5, 0.5, 0.5)
+    v, a, z, t = true_values
+
+    model = hssm.HSSM(
+        model="ddm",
+        data=cav_data,
+        p_outlier=p_outlier,
+        loglik_kind=loglik_kind,
+        loglik=Path(__file__).parent / "fixtures" / "ddm.onnx"
+        if loglik_kind == "approx_differentiable"
+        else None,
+        prior_settings=None,  # Avoid unnecessary computation
+        lapse=None
+        if p_outlier == 0.0
+        else hssm.Prior("Uniform", lower=0.0, upper=10.0),
+    )
+    distribution = (
+        model.model_distribution.dist(v=v, a=a, z=z, t=t, p_outlier=p_outlier)
+        if p_outlier > 0
+        else model.model_distribution.dist(v=v, a=a, z=z, t=t)
+    )
+
+    cav_data_numpy = cav_data[["rt", "response"]].values
+
+    # Convert to float32 if the blackbox loglik is used.
+    # This is necessary because the blackbox likelihood function logp_ddm_bbox
+    # does not go through any PyTensor function when called standalone, so it
+    # does not respect the floatX setting.
+
+    # This step is not necessary for HSSM as a whole because the likelihood
+    # function is part of a PyTensor graph there, so floatX is respected.
+    cav_data_numpy = (
+        cav_data_numpy.astype("float32")
+        if loglik_kind == "blackbox"
+        else cav_data_numpy
+    )
+
+    model_logp = pm.logp(distribution, cav_data_numpy).eval()
+
+    if loglik_kind == "analytical":
+        logp_func = logp_ddm
+    elif loglik_kind == "approx_differentiable":
+        logp_func = make_likelihood_callable(
+            loglik=Path(__file__).parent / "fixtures" / "ddm.onnx",
+            loglik_kind="approx_differentiable",
+            backend="pytensor",
+            params_is_reg=[False] * 4,
+        )
+    else:
+        logp_func = logp_ddm_bbox
+
+    manual_logp = logp_func(cav_data_numpy, *true_values)
+    if p_outlier == 0.0:
+        manual_logp = pt.where(
+            pt.sub(cav_data_numpy[:, 0], t) <= 1e-15, -66.1, manual_logp
+        ).eval()
+        np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4)
+        return
+
+    manual_logp = pt.where(
+        pt.sub(cav_data_numpy[:, 0], t) <= 1e-15,
+        -66.1,
+        manual_logp,
+    )
+    manual_logp = pt.log(
+        (1 - p_outlier) * pt.exp(manual_logp)
+        + p_outlier
+        * pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), cav_data_numpy[:, 0]))
+    ).eval()
+
+    np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4)
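The dtype comment in the test above is worth unpacking: a pure-NumPy (blackbox) likelihood simply propagates the dtype of whatever array it receives, whereas data routed through a PyTensor graph is cast according to floatX. A minimal standalone illustration, not HSSM code; standalone_loglik here is a made-up stand-in for logp_ddm_bbox:

import numpy as np

def standalone_loglik(data):
    # Stand-in for a blackbox likelihood: plain NumPy, so the output dtype
    # follows the input dtype, regardless of any floatX setting.
    return np.log(np.abs(data[:, 0]) + 1.0)

data = np.random.default_rng(0).random((5, 2))           # float64 by default
print(standalone_loglik(data).dtype)                     # float64
print(standalone_loglik(data.astype("float32")).dtype)   # float32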
