Skip to content

Commit

Permalink
Add tests for lapse distribution using cav data
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Jun 4, 2024
1 parent 42f02a1 commit 0c98c02
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions tests/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
"""

import math
from pathlib import Path
from itertools import product

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
from numpy.random import rand

Expand All @@ -15,6 +19,7 @@
# pylint: disable=C0413
from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
from hssm.distribution_utils import make_likelihood_callable

hssm.set_floatX("float32")

Expand Down Expand Up @@ -121,3 +126,85 @@ def test_bbox(data_ddm):
logp_ddm_sdv_bbox(data, *true_values_sdv),
decimal=4,
)


cav_data = hssm.load_data("cavanagh_theta")
cav_data_numpy = cav_data[["rt", "response"]].values
param_matrix = product(
(0.0, 0.01, 0.05, 0.5), ("analytical", "approx_differentiable", "blackbox")
)


@pytest.mark.parametrize("p_outlier, loglik_kind", param_matrix)
def test_lapse_distribution_cav(p_outlier, loglik_kind):
true_values = (0.5, 1.5, 0.5, 0.5)
v, a, z, t = true_values

model = hssm.HSSM(
model="ddm",
data=cav_data,
p_outlier=p_outlier,
loglik_kind=loglik_kind,
loglik=Path(__file__).parent / "fixtures" / "ddm.onnx"
if loglik_kind == "approx_differentiable"
else None,
prior_settings=None, # Avoid unnecessary computation
lapse=None
if p_outlier == 0.0
else hssm.Prior("Uniform", lower=0.0, upper=10.0),
)
distribution = (
model.model_distribution.dist(v=v, a=a, z=z, t=t, p_outlier=p_outlier)
if p_outlier > 0
else model.model_distribution.dist(v=v, a=a, z=z, t=t)
)

cav_data_numpy = cav_data[["rt", "response"]].values

# Convert to float32 if blackbox loglik is used
# This is necessary because the blackbox likelihood function logp_ddm_bbox is
# does not go through any PyTensor function standalone so does not respect the
# floatX setting

# This step is not necessary for HSSM as a whole because the likelihood function
# will be part of a PyTensor graph so the floatX setting will be respected
cav_data_numpy = (
cav_data_numpy.astype("float32")
if loglik_kind == "blackbox"
else cav_data_numpy
)

model_logp = pm.logp(distribution, cav_data_numpy).eval()

if loglik_kind == "analytical":
logp_func = logp_ddm
elif loglik_kind == "approx_differentiable":
logp_func = make_likelihood_callable(
loglik=Path(__file__).parent / "fixtures" / "ddm.onnx",
loglik_kind="approx_differentiable",
backend="pytensor",
params_is_reg=[False] * 4,
)
else:
logp_func = logp_ddm_bbox

manual_logp = logp_func(cav_data_numpy, *true_values)
if p_outlier == 0.0:
manual_logp = pt.where(
pt.sub(cav_data_numpy[:, 0], t) <= 1e-15, -66.1, manual_logp
).eval()
np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4)
return

manual_logp = pt.where(
pt.sub(cav_data_numpy[:, 0], t) <= 1e-15,
-66.1,
manual_logp,
)
manual_logp = pt.log(
(1 - p_outlier) * pt.exp(manual_logp)
+ p_outlier
* pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), cav_data_numpy[:, 0]))
).eval()

np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4)

0 comments on commit 0c98c02

Please sign in to comment.