diff --git a/docs/tutorials/main_tutorial.ipynb b/docs/tutorials/main_tutorial.ipynb
index 7bf714e8..24724f5d 100644
--- a/docs/tutorials/main_tutorial.ipynb
+++ b/docs/tutorials/main_tutorial.ipynb
@@ -407,21 +407,20 @@
     "import os\n",
     "import sys\n",
     "import time\n",
-    "from matplotlib import pyplot as plt\n",
+    "\n",
     "import arviz as az  # Visualization\n",
-    "import pytensor  # Graph-based tensor library\n",
-    "import hssm\n",
+    "import bambi as bmb\n",
     "\n",
     "# import ssms.basic_simulators # Model simulators\n",
     "import hddm_wfpt\n",
-    "import bambi as bmb\n",
+    "import hssm\n",
+    "import jax\n",
+    "import pytensor  # Graph-based tensor library\n",
+    "from matplotlib import pyplot as plt\n",
     "\n",
     "# Setting float precision in pytensor\n",
     "pytensor.config.floatX = \"float32\"\n",
-    "\n",
-    "from jax.config import config\n",
-    "\n",
-    "config.update(\"jax_enable_x64\", False)"
+    "jax.config.update(\"jax_enable_x64\", False)"
    ]
   },
   {
@@ -4004,9 +4003,7 @@
     }
    ],
    "source": [
-    "from jax.config import config\n",
-    "\n",
-    "config.update(\"jax_enable_x64\", False)\n",
+    "jax.config.update(\"jax_enable_x64\", False)\n",
     "infer_data_angle = model_angle.sample(\n",
     "    sampler=\"nuts_numpyro\",\n",
     "    chains=2,\n",
@@ -10117,9 +10114,7 @@
     }
    ],
    "source": [
-    "from jax.config import config\n",
-    "\n",
-    "config.update(\"jax_enable_x64\", False)\n",
+    "jax.config.update(\"jax_enable_x64\", False)\n",
     "model_reg_v_angle_hier.sample(\n",
     "    sampler=\"nuts_numpyro\", chains=2, cores=1, draws=1000, tune=1000\n",
     ")"
    ]
   },
   {
diff --git a/pyproject.toml b/pyproject.toml
index 661f3450..48926d3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ bambi = "^0.13.0"
 numpyro = "^0.15.0"
 hddm-wfpt = "^0.1.4"
 seaborn = "^0.13.2"
-jax = { version = ">=0.4.23,<0.4.28", extras = ["cuda12"], optional = true }
+jax = { version = "^0.4.25", extras = ["cuda12"], optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.2.0"
diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index 179d0915..b6c15e54 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -13,13 +13,13 @@
 from typing import Any, Iterable, Literal, NewType
 
 import bambi as bmb
+import jax
 import numpy as np
 import pandas as pd
 import pytensor
 import xarray as xr
 from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
 from huggingface_hub import hf_hub_download
-from jax import config
 from pymc.model_graph import ModelGraph
 from pytensor import function
 
@@ -258,7 +258,7 @@ def make_graph(
     return graph
 
 
-def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
+def set_floatX(dtype: Literal["float32", "float64"], update_jax: bool = True):
     """Set float types for pytensor and Jax.
 
     Often we wish to work with a specific type of float in both PyTensor and JAX.
@@ -268,7 +268,7 @@ def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
     ----------
     dtype
         Either `float32` or `float64`. Float type for pytensor (and jax if `jax=True`).
-    jax : optional
+    update_jax : optional
         Whether this function also sets float type for JAX by changing the
         `jax_enable_x64` setting in JAX config. Defaults to True.
     """
@@ -278,9 +278,9 @@ def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
     pytensor.config.floatX = dtype
     _logger.info("Setting PyTensor floatX type to %s.", dtype)
 
-    if jax:
+    if update_jax:
         jax_enable_x64 = dtype == "float64"
-        config.update("jax_enable_x64", jax_enable_x64)
+        jax.config.update("jax_enable_x64", jax_enable_x64)
         _logger.info(
             'Setting "jax_enable_x64" to %s. '
diff --git a/tests/slow/test_mcmc.py b/tests/slow/test_mcmc.py
index c2d508f8..2ddc6274 100644
--- a/tests/slow/test_mcmc.py
+++ b/tests/slow/test_mcmc.py
@@ -10,7 +10,7 @@
 
 from hssm.utils import _rearrange_data
 
-hssm.set_floatX("float32", jax=True)
+hssm.set_floatX("float32", update_jax=True)
 
 # AF-TODO: Include more tests that use different link functions!
diff --git a/tests/test_hssm.py b/tests/test_hssm.py
index 76c0e3d7..88ca46e3 100644
--- a/tests/test_hssm.py
+++ b/tests/test_hssm.py
@@ -7,7 +7,7 @@
 from hssm.utils import download_hf
 from hssm.likelihoods import DDM, logp_ddm
 
-hssm.set_floatX("float32", jax=True)
+hssm.set_floatX("float32", update_jax=True)
 
 param_v = {
     "name": "v",
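Taken together, the patch replaces the removed `from jax.config import config` import with attribute access on the top-level `jax` module, and renames the `jax` keyword of `hssm.set_floatX` to `update_jax` so the argument no longer shadows the imported module. A minimal sketch of the post-patch usage, assembled only from calls that already appear in the diff above (the package pins and import layout are assumptions, not part of the patch):

```python
# Minimal sketch of the float-precision setup after this patch, assuming
# hssm, pytensor, and jax (within the pinned ^0.4.25 range) are installed.
import jax
import pytensor

import hssm

# Option 1: configure each library directly. `jax.config` is reached through
# the top-level module, so the removed `from jax.config import config` import
# is no longer needed.
pytensor.config.floatX = "float32"
jax.config.update("jax_enable_x64", False)

# Option 2: the hssm helper, which sets both libraries in one call. Note the
# renamed keyword: `update_jax` replaces the old `jax` argument.
hssm.set_floatX("float32", update_jax=True)
```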