From 64873e6f5161e4678f89be0c1760bd721cf7ccd2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 19 Sep 2024 15:14:36 -0400 Subject: [PATCH 01/11] Add size enumeration and log domain check functions for RTs --- src/hssm/likelihoods/analytical.py | 58 ++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index d626d26f..aeda7268 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -5,6 +5,7 @@ https://gist.github.com/sammosummo/c1be633a74937efaca5215da776f194b. """ +from enum import Enum from typing import Type import numpy as np @@ -19,6 +20,63 @@ LOGP_LB = pm.floatX(-66.1) +# define enum large and small +class _Size(Enum): + LARGE = 1 + SMALL = 0 + + +def _log_bound_error_msg(bound: float, size: _Size) -> str: + bound_formula = "1 / (π * err)" if size.LARGE == 1 else "1 / (8 * err^2 * π)" + msg = ( + f"RTs must be less than {bound} = {bound_formula} " + f"for the {size.name.lower()} expansion." + ) + return msg + + +def get_bound(err: float, size: _Size) -> float: + """Get the bound for kl/ks log operations. + + Parameters + ---------- + err + Error bound. + size + Option for type of k terms, large (1) or small (0). + + Returns + ------- + float + The bound for the log operations. + """ + return 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) + + +def check_rt_log_domain( + rt: np.ndarray, err: float, size: _Size, clip_values=True +) -> np.ndarray: + """Check if the RTs are within correct domain for log operations for kl/ks. + + Parameters + ---------- + rt + Flipped RTs. + err + Error bound. + size + Option for type of k terms, large (1) or small (0). + """ + bound = 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) + epsilon = np.finfo(float).eps + + if clip_values: + return np.clip(rt, epsilon, bound * (1 - epsilon)) + + if not np.all(rt < bound): + raise ValueError(_log_bound_error_msg(bound, size)) + + def k_small(rt: np.ndarray, err: float) -> np.ndarray: """Determine number of terms needed for small-t expansion. From 571080841c3da0a687ce9ad62db60a233faa7a1f Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 19 Sep 2024 15:15:05 -0400 Subject: [PATCH 02/11] Refactor k_small and k_large functions to improve log domain checks and enhance readability --- src/hssm/likelihoods/analytical.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index aeda7268..3e45a008 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -92,9 +92,15 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_small. """ - ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err)) + log_arg = 2 * np.sqrt(2 * np.pi * rt) + log_arg = check_rt_log_domain(log_arg, err, _Size.SMALL) + sqrt_arg = -2 * rt * pt.log(log_arg) * err + + ks = 2 + pt.sqrt(sqrt_arg) ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0) - ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2) + + condition = 2 * pt.sqrt(2 * np.pi * rt) * err < 1 + ks = pt.switch(condition, ks, 2) return ks @@ -114,9 +120,16 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_large. """ - kl = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt)) - kl = pt.max(pt.stack([kl, 1.0 / (np.pi * pt.sqrt(rt))]), axis=0) - kl = pt.switch(np.pi * rt * err < 1, kl, 1.0 / (np.pi * pt.sqrt(rt))) + log_arg = np.pi * rt * err + log_arg = check_rt_log_domain(log_arg, err, _Size.LARGE) + divisor = np.pi**2 * rt + alternate = 1.0 / (np.pi * pt.sqrt(rt)) + + kl = pt.sqrt(-2 * pt.log(log_arg) / divisor) + kl = pt.max(pt.stack([kl, alternate]), axis=0) + + condition = np.pi * rt * err < 1 + kl = pt.switch(condition, kl, alternate) return kl From dfa60ae73178f0730e2c378db01cd5d779890768 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 19 Sep 2024 15:15:23 -0400 Subject: [PATCH 03/11] Add tests for numerical stability utilities in likelihood functions --- tests/test_numerical_stability_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/test_numerical_stability_utils.py diff --git a/tests/test_numerical_stability_utils.py b/tests/test_numerical_stability_utils.py new file mode 100644 index 00000000..779d553a --- /dev/null +++ b/tests/test_numerical_stability_utils.py @@ -0,0 +1,24 @@ +"""Tests for the utilities to ensure numerical stability of likelihood functions.""" + +import numpy as np +import pytest + +from hssm.likelihoods.analytical import _Size, check_rt_log_domain, _log_bound_error_msg + + +def test_check_rt_log_domain(): + err = 1e-2 + epsilon = np.finfo(float).eps + for size in _Size: + bound = 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) + rt = np.array([bound * (1 + epsilon)]) + + with pytest.raises(ValueError, match=r"^RTs must be less than"): + check_rt_log_domain(rt, err, size, clip_values=False) + + _rt = check_rt_log_domain(rt, err, size) + check_rt_log_domain(_rt, err, size, clip_values=False) + + +if __name__ == "__main__": + pytest.main([__file__]) From 8cb5a1711e96012e66e603483b66bde14cdbfc19 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Thu, 19 Sep 2024 16:36:49 -0400 Subject: [PATCH 04/11] Add return statement to check_rt_log_domain function for consistency --- src/hssm/likelihoods/analytical.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index 3e45a008..24cd747f 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -76,6 +76,8 @@ def check_rt_log_domain( if not np.all(rt < bound): raise ValueError(_log_bound_error_msg(bound, size)) + return np.array([]) + def k_small(rt: np.ndarray, err: float) -> np.ndarray: """Determine number of terms needed for small-t expansion. From 98c60690ad840ef862e1a42d39e426a3f6fd7e40 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 11:59:06 -0400 Subject: [PATCH 05/11] Add logp_ddm_combined function to compute log likelihood with optional drift rate standard deviation --- src/hssm/likelihoods/analytical.py | 61 ++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index 24cd747f..a93dc41b 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -458,6 +458,67 @@ def logp_ddm_sdv( bounds=ddm_sdv_bounds, ) + +def logp_ddm_combined(data, v, a, z, t, err, k_terms, epsilon, sv=0): + """ + Compute the log likelihood of the drift diffusion model with optional standard deviation of the drift rate. + + Parameters + ---------- + data : np.ndarray + The data array containing response times and responses. + v : float + Drift rate. + a : float + Boundary separation. + z : float + Starting point [0, 1]. + t : float + Non-decision time [0, inf). + err : float + Error bound. + k_terms : int + Number of terms to use to approximate the PDF. + epsilon : float + A small positive number to prevent division by zero or taking the log of zero. + sv : float, optional + Standard deviation of the drift rate [0, inf). Default is 0. + + Returns + ------- + np.ndarray + The log likelihood of the drift diffusion model with the standard deviation of sv. + """ + if sv == 0: + return logp_ddm(data, v, a, z, t, err, k_terms, epsilon) + + data = pt.reshape(data, (-1, 2)).astype(pytensor.config.floatX) + rt = pt.abs(data[:, 0]) + response = data[:, 1] + flip = response > 0 + a = a * 2.0 + v_flipped = pt.switch(flip, -v, v) # transform v if x is upper-bound response + z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response + rt = rt - t + negative_rt = rt < 0 + rt = pt.switch(negative_rt, epsilon, rt) + + p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) + logp = pt.switch( + rt <= epsilon, + LOGP_LB, + pt.log(p) + + ( + (a * z_flipped * sv) ** 2 + - 2 * a * v_flipped * z_flipped + - (v_flipped**2) * rt + ) + / (2 * (sv**2) * rt + 2) + - 0.5 * pt.log(sv**2 * rt + 1) + - 2 * pt.log(a), + ) + return logp + # LBA From c073dd75ec2ac10cadeba2e05d3888f3df7e5f81 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 16:01:56 -0400 Subject: [PATCH 06/11] Add mathematical constants and helper function for element-wise maximum --- src/hssm/likelihoods/analytical.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index a93dc41b..ed7d12df 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -19,6 +19,15 @@ LOGP_LB = pm.floatX(-66.1) +π = np.pi +τ = 2 * π +log_τ = pt.log(τ) +log_4 = pt.log(4) + + +def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return pt.max(pt.stack([a, b]), axis=0) + # define enum large and small class _Size(Enum): From 194c130b74d30480f0ec7c519db05d8437402965 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 16:02:30 -0400 Subject: [PATCH 07/11] Refactor k_small function to improve calculations and enhance readability --- src/hssm/likelihoods/analytical.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index ed7d12df..a8efe7ac 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -103,14 +103,14 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_small. """ - log_arg = 2 * np.sqrt(2 * np.pi * rt) - log_arg = check_rt_log_domain(log_arg, err, _Size.SMALL) - sqrt_arg = -2 * rt * pt.log(log_arg) * err + sqrt_τ_rt = pt.sqrt(τ * rt) + log_rt = pt.log(rt) + rt_log_2_sqrt_τ_rt_times_2 = rt * (log_4 + log_τ + log_rt) - ks = 2 + pt.sqrt(sqrt_arg) - ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0) + ks = 2 + pt.sqrt(-err * rt_log_2_sqrt_τ_rt_times_2) + ks = _max(ks, pt.sqrt(rt) + 1) - condition = 2 * pt.sqrt(2 * np.pi * rt) * err < 1 + condition = 2 * sqrt_τ_rt * err < 1 ks = pt.switch(condition, ks, 2) return ks From b79ef60ce3b5cb67d86a529705f1e16b638fdca2 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 16:05:00 -0400 Subject: [PATCH 08/11] =?UTF-8?q?Replace=20np.pi=20with=20=CF=80=20for=20c?= =?UTF-8?q?onsistency=20in=20mathematical=20expressions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/hssm/likelihoods/analytical.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index a8efe7ac..b4ce7a6f 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -59,7 +59,7 @@ def get_bound(err: float, size: _Size) -> float: float The bound for the log operations. """ - return 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) + return 1 / (π * err) if size == _Size.LARGE else 1 / (8 * err**2 * π) def check_rt_log_domain( @@ -76,7 +76,7 @@ def check_rt_log_domain( size Option for type of k terms, large (1) or small (0). """ - bound = 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) + bound = 1 / (π * err) if size == _Size.LARGE else 1 / (8 * err**2 * π) epsilon = np.finfo(float).eps if clip_values: @@ -131,15 +131,15 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_large. """ - log_arg = np.pi * rt * err - log_arg = check_rt_log_domain(log_arg, err, _Size.LARGE) - divisor = np.pi**2 * rt - alternate = 1.0 / (np.pi * pt.sqrt(rt)) + log_arg = π * rt * err + # log_arg = check_rt_log_domain(log_arg, err, _Size.LARGE) + divisor = π**2 * rt + alternate = 1.0 / (π * pt.sqrt(rt)) kl = pt.sqrt(-2 * pt.log(log_arg) / divisor) kl = pt.max(pt.stack([kl, alternate]), axis=0) - condition = np.pi * rt * err < 1 + condition = π * rt * err < 1 kl = pt.switch(condition, kl, alternate) return kl @@ -223,7 +223,7 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray: c = pt.max(r, axis=0) p = pt.exp(c) * pt.sum(y * pt.exp(r - c), axis=0) # Normalize p - p = p / pt.sqrt(2 * np.pi * pt.power(tt, 3)) + p = p / pt.sqrt(2 * π * pt.power(tt, 3)) return p @@ -249,9 +249,9 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray: The approximated function f(tt|0, 1, w). """ k = get_ks(k_terms, fast=False) - y = k * pt.sin(k * np.pi * w) - r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2 - p = pt.sum(y * pt.exp(r), axis=0) * np.pi + y = k * pt.sin(k * π * w) + r = -pt.power(k, 2) * pt.power(π, 2) * tt / 2 + p = pt.sum(y * pt.exp(r), axis=0) * π return p From 1eef96aa6204b39c9672b9bb700245b28faa9002 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 16:25:27 -0400 Subject: [PATCH 09/11] Add constants and refactor k_small/large for improved calculations --- src/hssm/likelihoods/analytical.py | 77 ++++-------------------------- 1 file changed, 9 insertions(+), 68 deletions(-) diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index b4ce7a6f..f6bef250 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -21,6 +21,7 @@ π = np.pi τ = 2 * π +log_π = pt.log(π) log_τ = pt.log(τ) log_4 = pt.log(4) @@ -28,66 +29,6 @@ def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray: return pt.max(pt.stack([a, b]), axis=0) - -# define enum large and small -class _Size(Enum): - LARGE = 1 - SMALL = 0 - - -def _log_bound_error_msg(bound: float, size: _Size) -> str: - bound_formula = "1 / (π * err)" if size.LARGE == 1 else "1 / (8 * err^2 * π)" - msg = ( - f"RTs must be less than {bound} = {bound_formula} " - f"for the {size.name.lower()} expansion." - ) - return msg - - -def get_bound(err: float, size: _Size) -> float: - """Get the bound for kl/ks log operations. - - Parameters - ---------- - err - Error bound. - size - Option for type of k terms, large (1) or small (0). - - Returns - ------- - float - The bound for the log operations. - """ - return 1 / (π * err) if size == _Size.LARGE else 1 / (8 * err**2 * π) - - -def check_rt_log_domain( - rt: np.ndarray, err: float, size: _Size, clip_values=True -) -> np.ndarray: - """Check if the RTs are within correct domain for log operations for kl/ks. - - Parameters - ---------- - rt - Flipped RTs. - err - Error bound. - size - Option for type of k terms, large (1) or small (0). - """ - bound = 1 / (π * err) if size == _Size.LARGE else 1 / (8 * err**2 * π) - epsilon = np.finfo(float).eps - - if clip_values: - return np.clip(rt, epsilon, bound * (1 - epsilon)) - - if not np.all(rt < bound): - raise ValueError(_log_bound_error_msg(bound, size)) - - return np.array([]) - - def k_small(rt: np.ndarray, err: float) -> np.ndarray: """Determine number of terms needed for small-t expansion. @@ -131,16 +72,16 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_large. """ - log_arg = π * rt * err - # log_arg = check_rt_log_domain(log_arg, err, _Size.LARGE) - divisor = π**2 * rt - alternate = 1.0 / (π * pt.sqrt(rt)) + log_rt = pt.log(rt) + sqrt_rt = pt.sqrt(rt) + log_err = pt.log(err) - kl = pt.sqrt(-2 * pt.log(log_arg) / divisor) - kl = pt.max(pt.stack([kl, alternate]), axis=0) + π_rt_err = π * rt * err + π_sqrt_rt = π * sqrt_rt - condition = π * rt * err < 1 - kl = pt.switch(condition, kl, alternate) + kl = pt.sqrt(-2 * (log_π + log_err + log_rt)) / π_sqrt_rt + kl = _max(kl, 1.0 / pt.sqrt(π_sqrt_rt)) + kl = pt.switch(π_rt_err < 1, kl, 1.0 / π_sqrt_rt) return kl From c9c590f4ecdd86f1d657e5e492a4c401ad94e09c Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 18:41:59 -0400 Subject: [PATCH 10/11] =?UTF-8?q?Add=20sqrt=5F=CF=84=20and=20refactor=20k?= =?UTF-8?q?=5Fsmall=20function=20for=20improved=20calculations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/hssm/likelihoods/analytical.py | 70 ++----------------------- tests/test_numerical_stability_utils.py | 24 --------- 2 files changed, 5 insertions(+), 89 deletions(-) delete mode 100644 tests/test_numerical_stability_utils.py diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py index f6bef250..860e8f07 100644 --- a/src/hssm/likelihoods/analytical.py +++ b/src/hssm/likelihoods/analytical.py @@ -5,7 +5,6 @@ https://gist.github.com/sammosummo/c1be633a74937efaca5215da776f194b. """ -from enum import Enum from typing import Type import numpy as np @@ -21,6 +20,7 @@ π = np.pi τ = 2 * π +sqrt_τ = pt.sqrt(τ) log_π = pt.log(π) log_τ = pt.log(τ) log_4 = pt.log(4) @@ -29,6 +29,7 @@ def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray: return pt.max(pt.stack([a, b]), axis=0) + def k_small(rt: np.ndarray, err: float) -> np.ndarray: """Determine number of terms needed for small-t expansion. @@ -44,14 +45,14 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray: np.ndarray A 1D at array of k_small. """ - sqrt_τ_rt = pt.sqrt(τ * rt) + sqrt_rt = pt.sqrt(rt) log_rt = pt.log(rt) rt_log_2_sqrt_τ_rt_times_2 = rt * (log_4 + log_τ + log_rt) ks = 2 + pt.sqrt(-err * rt_log_2_sqrt_τ_rt_times_2) - ks = _max(ks, pt.sqrt(rt) + 1) + ks = _max(ks, sqrt_rt + 1) - condition = 2 * sqrt_τ_rt * err < 1 + condition = 2 * sqrt_τ * sqrt_rt * err < 1 ks = pt.switch(condition, ks, 2) return ks @@ -408,67 +409,6 @@ def logp_ddm_sdv( bounds=ddm_sdv_bounds, ) - -def logp_ddm_combined(data, v, a, z, t, err, k_terms, epsilon, sv=0): - """ - Compute the log likelihood of the drift diffusion model with optional standard deviation of the drift rate. - - Parameters - ---------- - data : np.ndarray - The data array containing response times and responses. - v : float - Drift rate. - a : float - Boundary separation. - z : float - Starting point [0, 1]. - t : float - Non-decision time [0, inf). - err : float - Error bound. - k_terms : int - Number of terms to use to approximate the PDF. - epsilon : float - A small positive number to prevent division by zero or taking the log of zero. - sv : float, optional - Standard deviation of the drift rate [0, inf). Default is 0. - - Returns - ------- - np.ndarray - The log likelihood of the drift diffusion model with the standard deviation of sv. - """ - if sv == 0: - return logp_ddm(data, v, a, z, t, err, k_terms, epsilon) - - data = pt.reshape(data, (-1, 2)).astype(pytensor.config.floatX) - rt = pt.abs(data[:, 0]) - response = data[:, 1] - flip = response > 0 - a = a * 2.0 - v_flipped = pt.switch(flip, -v, v) # transform v if x is upper-bound response - z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response - rt = rt - t - negative_rt = rt < 0 - rt = pt.switch(negative_rt, epsilon, rt) - - p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB)) - logp = pt.switch( - rt <= epsilon, - LOGP_LB, - pt.log(p) - + ( - (a * z_flipped * sv) ** 2 - - 2 * a * v_flipped * z_flipped - - (v_flipped**2) * rt - ) - / (2 * (sv**2) * rt + 2) - - 0.5 * pt.log(sv**2 * rt + 1) - - 2 * pt.log(a), - ) - return logp - # LBA diff --git a/tests/test_numerical_stability_utils.py b/tests/test_numerical_stability_utils.py deleted file mode 100644 index 779d553a..00000000 --- a/tests/test_numerical_stability_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Tests for the utilities to ensure numerical stability of likelihood functions.""" - -import numpy as np -import pytest - -from hssm.likelihoods.analytical import _Size, check_rt_log_domain, _log_bound_error_msg - - -def test_check_rt_log_domain(): - err = 1e-2 - epsilon = np.finfo(float).eps - for size in _Size: - bound = 1 / (np.pi * err) if size == _Size.LARGE else 1 / (8 * err**2 * np.pi) - rt = np.array([bound * (1 + epsilon)]) - - with pytest.raises(ValueError, match=r"^RTs must be less than"): - check_rt_log_domain(rt, err, size, clip_values=False) - - _rt = check_rt_log_domain(rt, err, size) - check_rt_log_domain(_rt, err, size, clip_values=False) - - -if __name__ == "__main__": - pytest.main([__file__]) From c56d654cd6c3029b4ee36f9d277ea499e71f0c19 Mon Sep 17 00:00:00 2001 From: Carlos Paniagua Date: Wed, 25 Sep 2024 23:31:31 -0400 Subject: [PATCH 11/11] Refactor LBA tests to improve parameter handling and add vectorization helpers --- tests/test_likelihoods_lba.py | 161 +++++++++++----------------------- 1 file changed, 50 insertions(+), 111 deletions(-) diff --git a/tests/test_likelihoods_lba.py b/tests/test_likelihoods_lba.py index 1508b095..4cbef88a 100644 --- a/tests/test_likelihoods_lba.py +++ b/tests/test_likelihoods_lba.py @@ -1,143 +1,82 @@ """Unit testing for LBA likelihood functions.""" -from pathlib import Path from itertools import product import numpy as np -import pandas as pd import pymc as pm -import pytensor -import pytensor.tensor as pt import pytest -import arviz as az -from pytensor.compile.nanguardmode import NanGuardMode import hssm # pylint: disable=C0413 from hssm.likelihoods.analytical import logp_lba2, logp_lba3 -from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox -from hssm.distribution_utils import make_likelihood_callable hssm.set_floatX("float32") CLOSE_TOLERANCE = 1e-4 -def test_lba2_basic(): - size = 1000 +def filter_theta(theta, exclude_keys=["A", "b"]): + """Filter out specific keys from the theta dictionary.""" + return {k: v for k, v in theta.items() if k not in exclude_keys} - lba_data_out = hssm.simulate_data( - model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size - ) - - # Test if vectorization ok across parameters - out_A_vec = logp_lba2( - lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0 - ).eval() - out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval() - assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE) - - out_b_vec = logp_lba2( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=1.0, - v1=1.0, - ).eval() - assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE) - - out_v_vec = logp_lba2( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=np.array([1.0] * size), - v1=np.array([1.0] * size), - ).eval() - assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE) - # Test A > b leads to error +def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta): + """Helper function to assert ParameterValueError for given parameters.""" with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0 + logp_func( + lba_data_out.values, + A=A, + b=b, + **filter_theta(theta, ["A", "b"]), ).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0 - ).eval() +def vectorize_param(theta, param, size): + """ + Vectorize a specific parameter in the theta dictionary. - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, - A=np.array([0.6] * 1000), - b=np.array([0.5] * 1000), - v0=1.0, - v1=1.0, - ).eval() + Parameters: + theta (dict): Dictionary of parameters. + param (str): The parameter to vectorize. + size (int): The size of the vector. + Returns: + dict: A new dictionary with the specified parameter vectorized. -def test_lba3_basic(): - size = 1000 + Examples: + >>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0} + >>> vectorize_param(theta, "A", 3) + {'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0} - lba_data_out = hssm.simulate_data( - model="lba3", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0), size=size - ) - - # Test if vectorization ok across parameters - out_A_vec = logp_lba3( - lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0, v2=1.0 - ).eval() - - out_base = logp_lba3( - lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0 - ).eval() - - assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE) - - out_b_vec = logp_lba3( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=1.0, - v1=1.0, - v2=1.0, - ).eval() - assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE) - - out_v_vec = logp_lba3( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=np.array([1.0] * size), - v1=np.array([1.0] * size), - v2=np.array([1.0] * size), - ).eval() - assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE) + >>> vectorize_param(theta, "v0", 2) + {'A': 0.2, 'b': 0.5, 'v0': array([1., 1.]), 'v1': 1.0} + """ + return {k: (np.full(size, v) if k == param else v) for k, v in theta.items()} - # Test A > b leads to error - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba3( - lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0, v2=1.0 - ).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba3(lba_data_out.values, b=0.5, A=0.6, v0=1.0, v1=1.0, v2=1.0).eval() +theta_lba2 = dict(A=0.2, b=0.5, v0=1.0, v1=1.0) +theta_lba3 = theta_lba2 | {"v2": 1.0} - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba3( - lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0, v2=1.0 - ).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba3( - lba_data_out.values, - A=np.array([0.6] * 1000), - b=np.array([0.5] * 1000), - v0=1.0, - v1=1.0, - v2=1.0, - ).eval() +@pytest.mark.parametrize( + "logp_func, model, theta", + [(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)], +) +def test_lba(logp_func, model, theta): + size = 1000 + lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size) + + # Test if vectorization is ok across parameters + for param in theta: + param_vec = vectorize_param(theta, param, size) + out_vec = logp_func(lba_data_out.values, **param_vec).eval() + out_base = logp_func(lba_data_out.values, **theta).eval() + assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE) + + # Test A > b leads to error + A_values = [np.full(size, 0.6), 0.6] + b_values = [np.full(size, 0.5), 0.5] + + for A, b in product(A_values, b_values): + assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)