diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index 1f0b2f10..d3ed7740 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -452,6 +452,9 @@ def logp(data, *dist_params): # pylint: disable=E0213 lapse_logp = lapse_func(data[:, 0].eval()) # AF-TODO potentially apply clipping here logp = loglik(data, *dist_params, *extra_fields) + # Ensure that non-decision time is always smaller than rt. + # Assuming that the non-decision time parameter is always named "t". + logp = ensure_positive_ndt(data, logp, list_params, dist_params) logp = pt.log( (1.0 - p_outlier) * pt.exp(logp) + p_outlier * pt.exp(lapse_logp) @@ -459,15 +462,16 @@ def logp(data, *dist_params): # pylint: disable=E0213 ) else: logp = loglik(data, *dist_params, *extra_fields) + # Ensure that non-decision time is always smaller than rt. + # Assuming that the non-decision time parameter is always named "t". + logp = ensure_positive_ndt(data, logp, list_params, dist_params) if bounds is not None: logp = apply_param_bounds_to_loglik( logp, list_params, *dist_params, bounds=bounds ) - # Ensure that non-decision time is always smaller than rt. - # Assuming that the non-decision time parameter is always named "t". 
- return ensure_positive_ndt(data, logp, list_params, dist_params) + return logp return SSMDistribution diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index c3209362..f0dc440f 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -5,8 +5,12 @@ """ import math +from pathlib import Path +from itertools import product import numpy as np +import pymc as pm +import pytensor.tensor as pt import pytest from numpy.random import rand @@ -15,6 +19,7 @@ # pylint: disable=C0413 from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox +from hssm.distribution_utils import make_likelihood_callable hssm.set_floatX("float32") @@ -121,3 +126,85 @@ def test_bbox(data_ddm): logp_ddm_sdv_bbox(data, *true_values_sdv), decimal=4, ) + + +cav_data = hssm.load_data("cavanagh_theta") +cav_data_numpy = cav_data[["rt", "response"]].values +param_matrix = product( + (0.0, 0.01, 0.05, 0.5), ("analytical", "approx_differentiable", "blackbox") ) + + +@pytest.mark.parametrize("p_outlier, loglik_kind", param_matrix) +def test_lapse_distribution_cav(p_outlier, loglik_kind): + true_values = (0.5, 1.5, 0.5, 0.5) + v, a, z, t = true_values + + model = hssm.HSSM( + model="ddm", + data=cav_data, + p_outlier=p_outlier, + loglik_kind=loglik_kind, + loglik=Path(__file__).parent / "fixtures" / "ddm.onnx" + if loglik_kind == "approx_differentiable" + else None, + prior_settings=None, # Avoid unnecessary computation + lapse=None + if p_outlier == 0.0 + else hssm.Prior("Uniform", lower=0.0, upper=10.0), + ) + distribution = ( + model.model_distribution.dist(v=v, a=a, z=z, t=t, p_outlier=p_outlier) + if p_outlier > 0 + else model.model_distribution.dist(v=v, a=a, z=z, t=t) + ) + + cav_data_numpy = cav_data[["rt", "response"]].values + + # Convert to float32 if blackbox loglik is used + # This is necessary because the blackbox likelihood function logp_ddm_bbox + # does not go through 
any PyTensor function standalone so does not respect the + # floatX setting + + # This step is not necessary for HSSM as a whole because the likelihood function + # will be part of a PyTensor graph so the floatX setting will be respected + cav_data_numpy = ( + cav_data_numpy.astype("float32") + if loglik_kind == "blackbox" + else cav_data_numpy + ) + + model_logp = pm.logp(distribution, cav_data_numpy).eval() + + if loglik_kind == "analytical": + logp_func = logp_ddm + elif loglik_kind == "approx_differentiable": + logp_func = make_likelihood_callable( + loglik=Path(__file__).parent / "fixtures" / "ddm.onnx", + loglik_kind="approx_differentiable", + backend="pytensor", + params_is_reg=[False] * 4, + ) + else: + logp_func = logp_ddm_bbox + + manual_logp = logp_func(cav_data_numpy, *true_values) + if p_outlier == 0.0: + manual_logp = pt.where( + pt.sub(cav_data_numpy[:, 0], t) <= 1e-15, -66.1, manual_logp + ).eval() + np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4) + return + + manual_logp = pt.where( + pt.sub(cav_data_numpy[:, 0], t) <= 1e-15, + -66.1, + manual_logp, + ) + manual_logp = pt.log( + (1 - p_outlier) * pt.exp(manual_logp) + + p_outlier + * pt.exp(pm.logp(pm.Uniform.dist(lower=0.0, upper=10.0), cav_data_numpy[:, 0])) + ).eval() + + np.testing.assert_almost_equal(model_logp, manual_logp, decimal=4)