Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

578 improve analytical likelihood #585

Merged
merged 11 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions src/hssm/likelihoods/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@

LOGP_LB = pm.floatX(-66.1)  # floor for log-probability values (used outside this view)

# Precomputed scalar constants shared by the likelihood helpers below, so the
# graph does not recompute them on every call.
π = np.pi  # pi
τ = 2 * π  # tau = 2 * pi
sqrt_τ = pt.sqrt(τ)  # sqrt(2 * pi)
log_π = pt.log(π)  # log(pi)
log_τ = pt.log(τ)  # log(2 * pi)
log_4 = pt.log(4)  # log(4) = 2 * log(2)


def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the element-wise maximum of two tensors.

    Stacks the operands along a new leading axis and reduces over it,
    which mirrors the pattern the k_small/k_large helpers previously inlined.
    """
    stacked = pt.stack([a, b])
    return pt.max(stacked, axis=0)


def k_small(rt: np.ndarray, err: float) -> np.ndarray:
    """Determine number of terms needed for small-t expansion.

    Parameters
    ----------
    rt
        A 1D array of flipped RTs. (0, inf).
    err
        Error bound.

    Returns
    -------
    np.ndarray
        A 1D pytensor array of k_small.
    """
    sqrt_rt = pt.sqrt(rt)
    log_rt = pt.log(rt)

    # Expand -2 * rt * log(2 * sqrt(2*pi*rt) * err) with precomputed logs:
    #   = -rt * (log 4 + log tau + log rt + 2 * log err)
    # NOTE: the error bound must enter through its logarithm (2 * log err);
    # multiplying the bracket by `err` itself is not equivalent.
    neg_two_rt_log_term = -rt * (log_4 + log_τ + log_rt + 2 * pt.log(err))

    ks = 2 + pt.sqrt(neg_two_rt_log_term)
    ks = _max(ks, sqrt_rt + 1)

    # The small-t series is only truncated this way when
    # 2 * sqrt(2*pi*rt) * err < 1; otherwise two terms suffice.
    condition = 2 * sqrt_τ * sqrt_rt * err < 1
    ks = pt.switch(condition, ks, 2)

    return ks

def k_large(rt: np.ndarray, err: float) -> np.ndarray:
    """Determine number of terms needed for large-t expansion.

    Parameters
    ----------
    rt
        A 1D array of flipped RTs. (0, inf).
    err
        Error bound.

    Returns
    -------
    np.ndarray
        A 1D pytensor array of k_large.
    """
    log_rt = pt.log(rt)
    sqrt_rt = pt.sqrt(rt)
    log_err = pt.log(err)

    π_rt_err = π * rt * err
    π_sqrt_rt = π * sqrt_rt

    # sqrt(-2 * log(pi * rt * err) / (pi**2 * rt))
    #   = sqrt(-2 * (log pi + log err + log rt)) / (pi * sqrt(rt))
    kl = pt.sqrt(-2 * (log_π + log_err + log_rt)) / π_sqrt_rt
    # Lower bound is 1 / (pi * sqrt(rt)); using 1 / sqrt(pi * sqrt(rt))
    # here would be an algebra slip and disagree with the switch fallback.
    kl = _max(kl, 1.0 / π_sqrt_rt)
    kl = pt.switch(π_rt_err < 1, kl, 1.0 / π_sqrt_rt)

    return kl

Expand Down Expand Up @@ -141,7 +165,7 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
c = pt.max(r, axis=0)
p = pt.exp(c) * pt.sum(y * pt.exp(r - c), axis=0)
# Normalize p
p = p / pt.sqrt(2 * np.pi * pt.power(tt, 3))
p = p / pt.sqrt(2 * π * pt.power(tt, 3))

return p

Expand All @@ -167,9 +191,9 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
The approximated function f(tt|0, 1, w).
"""
k = get_ks(k_terms, fast=False)
y = k * pt.sin(k * np.pi * w)
r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2
p = pt.sum(y * pt.exp(r), axis=0) * np.pi
y = k * pt.sin(k * π * w)
r = -pt.power(k, 2) * pt.power(π, 2) * tt / 2
p = pt.sum(y * pt.exp(r), axis=0) * π

return p

Expand Down
161 changes: 50 additions & 111 deletions tests/test_likelihoods_lba.py
Original file line number Diff line number Diff line change
@@ -1,143 +1,82 @@
"""Unit testing for LBA likelihood functions."""

from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import pytest
import arviz as az
from pytensor.compile.nanguardmode import NanGuardMode

import hssm

# pylint: disable=C0413
from hssm.likelihoods.analytical import logp_lba2, logp_lba3
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
from hssm.distribution_utils import make_likelihood_callable

hssm.set_floatX("float32")  # run this test module in single precision

# Absolute tolerance for np.allclose comparisons of logp arrays below.
CLOSE_TOLERANCE = 1e-4


def test_lba2_basic():
size = 1000
def filter_theta(theta, exclude_keys=("A", "b")):
    """Return a copy of ``theta`` without the keys listed in ``exclude_keys``.

    Parameters
    ----------
    theta : dict
        Dictionary of model parameters.
    exclude_keys : iterable of str, optional
        Keys to drop. Defaults to ``("A", "b")``.

    Returns
    -------
    dict
        New dictionary with the excluded keys removed.
    """
    # Tuple default avoids the mutable-default-argument pitfall; membership
    # semantics are unchanged for callers that still pass a list.
    return {k: v for k, v in theta.items() if k not in exclude_keys}

lba_data_out = hssm.simulate_data(
model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size
)

# Test if vectorization ok across parameters
out_A_vec = logp_lba2(
lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0
).eval()
out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval()
assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)

out_b_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=1.0,
v1=1.0,
).eval()
assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)

out_v_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=np.array([1.0] * size),
v1=np.array([1.0] * size),
).eval()
assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)

# Test A > b leads to error
def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta):
    """Assert that ``logp_func`` raises ParameterValueError for the given A/b.

    Parameters
    ----------
    logp_func
        An LBA logp callable (``logp_lba2`` or ``logp_lba3``).
    lba_data_out
        Simulated data whose ``.values`` are passed as observations.
    A, b
        Start-point upper bound and threshold; callers choose A > b, which
        the likelihood's parameter checks should reject.
    theta : dict
        Full parameter dict; ``A`` and ``b`` are stripped and the remaining
        parameters forwarded unchanged.
    """
    with pytest.raises(pm.logprob.utils.ParameterValueError):
        logp_func(
            lba_data_out.values,
            A=A,
            b=b,
            **filter_theta(theta, ["A", "b"]),
        ).eval()
def vectorize_param(theta, param, size):
    """
    Vectorize a specific parameter in the theta dictionary.

    Parameters:
    theta (dict): Dictionary of parameters.
    param (str): The parameter to vectorize.
    size (int): The size of the vector.

    Returns:
    dict: A new dictionary with the specified parameter vectorized.

    Examples:
    >>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
    >>> vectorize_param(theta, "A", 3)
    {'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0}
    """
    # Only the requested parameter is broadcast to a length-`size` vector;
    # every other entry is passed through untouched.
    return {k: (np.full(size, v) if k == param else v) for k, v in theta.items()}

# Baseline parameter sets for the LBA likelihood tests.
theta_lba2 = dict(A=0.2, b=0.5, v0=1.0, v1=1.0)
# LBA3 extends LBA2 with a third accumulator drift rate.
theta_lba3 = theta_lba2 | {"v2": 1.0}

@pytest.mark.parametrize(
    "logp_func, model, theta",
    [(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)],
)
def test_lba(logp_func, model, theta):
    """Check LBA logp vectorization and A > b parameter validation."""
    size = 1000
    lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size)

    # Baseline with all-scalar parameters; loop-invariant, so compute it once
    # instead of re-evaluating the graph on every iteration.
    out_base = logp_func(lba_data_out.values, **theta).eval()

    # Vectorizing any single parameter must not change the result.
    for param in theta:
        param_vec = vectorize_param(theta, param, size)
        out_vec = logp_func(lba_data_out.values, **param_vec).eval()
        assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE)

    # A > b must raise, for every scalar/vector combination of A and b.
    A_values = [np.full(size, 0.6), 0.6]
    b_values = [np.full(size, 0.5), 0.5]

    for A, b in product(A_values, b_values):
        assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)