Skip to content

Commit

Permalink
Add sqrt_τ and refactor k_small function for improved calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpaniaguam committed Sep 26, 2024
1 parent 1eef96a commit c9c590f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 89 deletions.
70 changes: 5 additions & 65 deletions src/hssm/likelihoods/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
https://gist.github.com/sammosummo/c1be633a74937efaca5215da776f194b.
"""

from enum import Enum
from typing import Type

import numpy as np
Expand All @@ -21,6 +20,7 @@

π = np.pi
τ = 2 * π
sqrt_τ = pt.sqrt(τ)
log_π = pt.log(π)
log_τ = pt.log(τ)
log_4 = pt.log(4)
Expand All @@ -29,6 +29,7 @@
def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray:
return pt.max(pt.stack([a, b]), axis=0)


def k_small(rt: np.ndarray, err: float) -> np.ndarray:
"""Determine number of terms needed for small-t expansion.
Expand All @@ -44,14 +45,14 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray:
np.ndarray
A 1D at array of k_small.
"""
sqrt_τ_rt = pt.sqrt(τ * rt)
sqrt_rt = pt.sqrt(rt)
log_rt = pt.log(rt)
rt_log_2_sqrt_τ_rt_times_2 = rt * (log_4 + log_τ + log_rt)

ks = 2 + pt.sqrt(-err * rt_log_2_sqrt_τ_rt_times_2)
ks = _max(ks, pt.sqrt(rt) + 1)
ks = _max(ks, sqrt_rt + 1)

condition = 2 * sqrt_τ_rt * err < 1
condition = 2 * sqrt_τ * sqrt_rt * err < 1
ks = pt.switch(condition, ks, 2)

return ks
Expand Down Expand Up @@ -408,67 +409,6 @@ def logp_ddm_sdv(
bounds=ddm_sdv_bounds,
)


def logp_ddm_combined(data, v, a, z, t, err, k_terms, epsilon, sv=0):
"""
Compute the log likelihood of the drift diffusion model with optional standard deviation of the drift rate.
Parameters
----------
data : np.ndarray
The data array containing response times and responses.
v : float
Drift rate.
a : float
Boundary separation.
z : float
Starting point [0, 1].
t : float
Non-decision time [0, inf).
err : float
Error bound.
k_terms : int
Number of terms to use to approximate the PDF.
epsilon : float
A small positive number to prevent division by zero or taking the log of zero.
sv : float, optional
Standard deviation of the drift rate [0, inf). Default is 0.
Returns
-------
np.ndarray
The log likelihood of the drift diffusion model with the standard deviation of sv.
"""
if sv == 0:
return logp_ddm(data, v, a, z, t, err, k_terms, epsilon)

data = pt.reshape(data, (-1, 2)).astype(pytensor.config.floatX)
rt = pt.abs(data[:, 0])
response = data[:, 1]
flip = response > 0
a = a * 2.0
v_flipped = pt.switch(flip, -v, v) # transform v if x is upper-bound response
z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response
rt = rt - t
negative_rt = rt < 0
rt = pt.switch(negative_rt, epsilon, rt)

p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
logp = pt.switch(
rt <= epsilon,
LOGP_LB,
pt.log(p)
+ (
(a * z_flipped * sv) ** 2
- 2 * a * v_flipped * z_flipped
- (v_flipped**2) * rt
)
/ (2 * (sv**2) * rt + 2)
- 0.5 * pt.log(sv**2 * rt + 1)
- 2 * pt.log(a),
)
return logp

# LBA


Expand Down
24 changes: 0 additions & 24 deletions tests/test_numerical_stability_utils.py

This file was deleted.

0 comments on commit c9c590f

Please sign in to comment.