From 42f02a1dd6b63068597144442fdfc0a259c48ce3 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 4 Jun 2024 14:01:11 -0400 Subject: [PATCH] Adjust the order in which ensure_positive_ndt is applied --- src/hssm/distribution_utils/dist.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index 1f0b2f10..d3ed7740 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -452,6 +452,9 @@ def logp(data, *dist_params): # pylint: disable=E0213 lapse_logp = lapse_func(data[:, 0].eval()) # AF-TODO potentially apply clipping here logp = loglik(data, *dist_params, *extra_fields) + # Ensure that non-decision time is always smaller than rt. + # Assuming that the non-decision time parameter is always named "t". + logp = ensure_positive_ndt(data, logp, list_params, dist_params) logp = pt.log( (1.0 - p_outlier) * pt.exp(logp) + p_outlier * pt.exp(lapse_logp) @@ -459,15 +462,16 @@ def logp(data, *dist_params): # pylint: disable=E0213 ) else: logp = loglik(data, *dist_params, *extra_fields) + # Ensure that non-decision time is always smaller than rt. + # Assuming that the non-decision time parameter is always named "t". + logp = ensure_positive_ndt(data, logp, list_params, dist_params) if bounds is not None: logp = apply_param_bounds_to_loglik( logp, list_params, *dist_params, bounds=bounds ) - # Ensure that non-decision time is always smaller than rt. - # Assuming that the non-decision time parameter is always named "t". - return ensure_positive_ndt(data, logp, list_params, dist_params) + return logp return SSMDistribution