Skip to content

Commit

Permalink
update ssm-simulators dependency and allow 3.12 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Sep 15, 2024
1 parent b042ef3 commit d5b083d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_fast_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Checkout repository
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_slow_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Checkout repository
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ repository = "https://github.com/lnccbrown/HSSM"
keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
python = ">=3.10,<=3.12"
pymc = ">=5.16.2,<5.17.0"
arviz = "^0.19.0"
onnx = "^1.16.0"
ssm-simulators = "^0.7.2"
ssm-simulators = "^0.7.5"
huggingface-hub = "^0.24.6"
bambi = ">=0.14.0,<0.15.0"
numpyro = "^0.15.2"
Expand Down
6 changes: 2 additions & 4 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,7 @@ def sample(
self.log_likelihood(self._inference_obj, inplace=True)

# Subset data vars in posterior
if self._inference_obj is not None:
self._clean_posterior_group(idata=self._inference_obj)
self._clean_posterior_group(idata=self._inference_obj)
return self.traces

def vi(
Expand Down Expand Up @@ -708,8 +707,7 @@ def vi(
self._inference_obj_vi = self._vi_approx.sample(draws)

# Post-processing
if self._inference_obj_vi is not None:
self._clean_posterior_group(idata=self._inference_obj_vi)
self._clean_posterior_group(idata=self._inference_obj_vi)

# Return the InferenceData object if return_idata is True
if return_idata:
Expand Down

0 comments on commit d5b083d

Please sign in to comment.