Skip to content

Commit

Permalink
add working log likelihood computation to hssm class (#577)
Browse files Browse the repository at this point in the history
* add working log likelihood computation to hssm class

* chasing mypy errors

* help mypy with typing

* likelihood is now compile before the loop

* fix mypy complaints

* fix linter complains

* more mypy madness

* syntax error nonesense ?!

* next attempt at fixing linting, and some safeguard added to compile_pymc`

* first pass add in lba

* endless linter madness

* even more linter madness....

* fix issue with missing variables for likelihood computation

* fix inconsistency in which parameters were kept in the posterior

* attempt to fix mypy complaints

* next attempt at dealing with mypy

* add tests

* some refinements and corresponding test fixes

* clean

* mypy madness begins again

* update ssm-simulators dependency and allow 3.12 tests

* address change requests

* fix some remaining comments and correct problems with tests
  • Loading branch information
AlexanderFengler authored Sep 17, 2024
1 parent d7c9ead commit dac648c
Show file tree
Hide file tree
Showing 13 changed files with 876 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_fast_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Checkout repository
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_slow_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Checkout repository
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/pymc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ repository = "https://github.com/lnccbrown/HSSM"
keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
python = ">=3.10,<=3.12"
pymc = ">=5.16.2,<5.17.0"
arviz = "^0.19.0"
onnx = "^1.16.0"
ssm-simulators = "^0.7.2"
ssm-simulators = "^0.7.5"
huggingface-hub = "^0.24.6"
bambi = ">=0.14.0,<0.15.0"
numpyro = "^0.15.2"
hddm-wfpt = "^0.1.4"
seaborn = "^0.13.2"
tqdm= "^4.66.0"
jax = { version = "^0.4.25", extras = ["cuda12"], optional = true }
numpy = ">=1.26.4,<2.0.0"

Expand Down
26 changes: 24 additions & 2 deletions src/hssm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Config:
model_name: SupportedModels | str
loglik_kind: LoglikKind
response: list[str] | None = None
choices: list[int] | None = None
list_params: list[str] | None = None
description: str | None = None
loglik: LogLik | None = None
Expand Down Expand Up @@ -63,6 +64,7 @@ def from_defaults(
model_name,
loglik_kind=kind,
response=default_config["response"],
choices=default_config["choices"],
list_params=default_config["list_params"],
description=default_config["description"],
**loglik_config,
Expand Down Expand Up @@ -90,6 +92,7 @@ def from_defaults(
model_name,
loglik_kind=loglik_kind,
response=default_config["response"],
choices=default_config["choices"],
list_params=default_config["list_params"],
description=default_config["description"],
**loglik_config,
Expand All @@ -98,6 +101,7 @@ def from_defaults(
model_name,
loglik_kind=loglik_kind,
response=default_config["response"],
choices=default_config["choices"],
list_params=default_config["list_params"],
description=default_config["description"],
)
Expand All @@ -117,18 +121,33 @@ def update_loglik(self, loglik: Any | None) -> None:

self.loglik = loglik

def update_choices(self, choices: list[int] | None) -> None:
"""Update the choices from user input.
Parameters
----------
choices : list[int]
A list of choices.
"""
if choices is None:
return

self.choices = choices

def update_config(self, user_config: ModelConfig) -> None:
"""Update the object from a ModelConfig object.
Parameters
----------
loglik : optional
A user-defined log-likelihood function.
user_config: ModelConfig
User specified ModelConfig used update self.
"""
if user_config.response is not None:
self.response = user_config.response
if user_config.list_params is not None:
self.list_params = user_config.list_params
if user_config.choices is not None:
self.choices = user_config.choices

if (
self.loglik_kind == "approx_differentiable"
Expand All @@ -146,6 +165,8 @@ def validate(self) -> None:
raise ValueError("Please provide `response` via `model_config`.")
if self.list_params is None:
raise ValueError("Please provide `list_params` via `model_config`.")
if self.choices is None:
raise ValueError("Please provide `choices` via `model_config`.")
if self.loglik is None:
raise ValueError("Please provide a log-likelihood function via `loglik`.")
if self.loglik_kind == "approx_differentiable" and self.backend is None:
Expand All @@ -170,6 +191,7 @@ class ModelConfig:

response: list[str] | None = None
list_params: list[str] | None = None
choices: list[int] | None = None
default_priors: dict[str, ParamSpec] = field(default_factory=dict)
bounds: dict[str, tuple[float, float]] = field(default_factory=dict)
backend: Literal["jax", "pytensor"] | None = None
Expand Down
48 changes: 48 additions & 0 deletions src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@
ddm_params,
ddm_sdv_bounds,
ddm_sdv_params,
lba2_bounds,
lba2_params,
lba3_bounds,
lba3_params,
logp_ddm,
logp_ddm_sdv,
logp_lba2,
logp_lba3,
)
from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm
from .param import ParamSpec, _make_default_prior
Expand All @@ -32,6 +38,8 @@
"weibull",
"race_no_bias_angle_4",
"ddm_seq2_no_bias",
"lba3",
"lba2",
]

LoglikKind = Literal["analytical", "approx_differentiable", "blackbox"]
Expand Down Expand Up @@ -72,6 +80,7 @@ class DefaultConfig(TypedDict):

response: list[str]
list_params: list[str]
choices: list[int]
description: Optional[str]
likelihoods: LoglikConfigs

Expand All @@ -82,6 +91,7 @@ class DefaultConfig(TypedDict):
"ddm": {
"response": ["rt", "response"],
"list_params": ddm_params,
"choices": [-1, 1],
"description": "The Drift Diffusion Model (DDM)",
"likelihoods": {
"analytical": {
Expand Down Expand Up @@ -130,6 +140,7 @@ class DefaultConfig(TypedDict):
"ddm_sdv": {
"response": ["rt", "response"],
"list_params": ddm_sdv_params,
"choices": [-1, 1],
"description": "The Drift Diffusion Model (DDM) with standard deviation for v",
"likelihoods": {
"analytical": {
Expand Down Expand Up @@ -179,6 +190,7 @@ class DefaultConfig(TypedDict):
"full_ddm": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "t", "sz", "sv", "st"],
"choices": [-1, 1],
"description": "The full Drift Diffusion Model (DDM)",
"likelihoods": {
"blackbox": {
Expand All @@ -195,9 +207,40 @@ class DefaultConfig(TypedDict):
}
},
},
"lba2": {
"response": ["rt", "response"],
"list_params": lba2_params,
"choices": [0, 1],
"description": "Linear Ballistic Accumulator 2 Choices (LBA2)",
"likelihoods": {
"analytical": {
"loglik": logp_lba2,
"backend": None,
"default_priors": {},
"bounds": lba2_bounds,
"extra_fields": None,
}
},
},
"lba3": {
"response": ["rt", "response"],
"list_params": lba3_params,
"choices": [0, 1, 2],
"description": "Linear Ballistic Accumulator 3 Choices (LBA3)",
"likelihoods": {
"analytical": {
"loglik": logp_lba3,
"backend": None,
"default_priors": {},
"bounds": lba3_bounds,
"extra_fields": None,
}
},
},
"angle": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "t", "theta"],
"choices": [-1, 1],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand All @@ -218,6 +261,7 @@ class DefaultConfig(TypedDict):
"levy": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "alpha", "t"],
"choices": [-1, 1],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand All @@ -238,6 +282,7 @@ class DefaultConfig(TypedDict):
"ornstein": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "g", "t"],
"choices": [-1, 1],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand All @@ -258,6 +303,7 @@ class DefaultConfig(TypedDict):
"weibull": {
"response": ["rt", "response"],
"list_params": ["v", "a", "z", "t", "alpha", "beta"],
"choices": [-1, 1],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand All @@ -279,6 +325,7 @@ class DefaultConfig(TypedDict):
"race_no_bias_angle_4": {
"response": ["rt", "response"],
"list_params": ["v0", "v1", "v2", "v3", "a", "z", "t", "theta"],
"choices": [0, 1, 2, 3],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand All @@ -302,6 +349,7 @@ class DefaultConfig(TypedDict):
"ddm_seq2_no_bias": {
"response": ["rt", "response"],
"list_params": ["vh", "vl1", "vl2", "a", "t"],
"choices": [0, 1, 2, 3],
"description": None,
"likelihoods": {
"approx_differentiable": {
Expand Down
Loading

0 comments on commit dac648c

Please sign in to comment.