Skip to content

Commit

Permalink
Merge pull request #443 from lnccbrown/442-jax-has-deprecated-the-jax…
Browse files Browse the repository at this point in the history
…config-module-from-0429

Update reference to jax.config
  • Loading branch information
AlexanderFengler authored May 23, 2024
2 parents 80c5248 + d9c94f3 commit a135502
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 22 deletions.
23 changes: 9 additions & 14 deletions docs/tutorials/main_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -407,21 +407,20 @@
"import os\n",
"import sys\n",
"import time\n",
"from matplotlib import pyplot as plt\n",
"\n",
"import arviz as az # Visualization\n",
"import pytensor # Graph-based tensor library\n",
"import hssm\n",
"import bambi as bmb\n",
"\n",
"# import ssms.basic_simulators # Model simulators\n",
"import hddm_wfpt\n",
"import bambi as bmb\n",
"import hssm\n",
"import jax\n",
"import pytensor # Graph-based tensor library\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# Setting float precision in pytensor\n",
"pytensor.config.floatX = \"float32\"\n",
"\n",
"from jax.config import config\n",
"\n",
"config.update(\"jax_enable_x64\", False)"
"jax.config.update(\"jax_enable_x64\", False)"
]
},
{
Expand Down Expand Up @@ -4004,9 +4003,7 @@
}
],
"source": [
"from jax.config import config\n",
"\n",
"config.update(\"jax_enable_x64\", False)\n",
"jax.config.update(\"jax_enable_x64\", False)\n",
"infer_data_angle = model_angle.sample(\n",
" sampler=\"nuts_numpyro\",\n",
" chains=2,\n",
Expand Down Expand Up @@ -10117,9 +10114,7 @@
}
],
"source": [
"from jax.config import config\n",
"\n",
"config.update(\"jax_enable_x64\", False)\n",
"jax.config.update(\"jax_enable_x64\", False)\n",
"model_reg_v_angle_hier.sample(\n",
" sampler=\"nuts_numpyro\", chains=2, cores=1, draws=1000, tune=1000\n",
")"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ bambi = "^0.13.0"
numpyro = "^0.15.0"
hddm-wfpt = "^0.1.4"
seaborn = "^0.13.2"
jax = { version = ">=0.4.23,<0.4.28", extras = ["cuda12"], optional = true }
jax = { version = "^0.4.25", extras = ["cuda12"], optional = true }

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand Down
10 changes: 5 additions & 5 deletions src/hssm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from typing import Any, Iterable, Literal, NewType

import bambi as bmb
import jax
import numpy as np
import pandas as pd
import pytensor
import xarray as xr
from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
from huggingface_hub import hf_hub_download
from jax import config
from pymc.model_graph import ModelGraph
from pytensor import function

Expand Down Expand Up @@ -258,7 +258,7 @@ def make_graph(
return graph


def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
def set_floatX(dtype: Literal["float32", "float64"], update_jax: bool = True):
"""Set float types for pytensor and Jax.
Often we wish to work with a specific type of float in both PyTensor and JAX.
Expand All @@ -268,7 +268,7 @@ def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
----------
dtype
Either `float32` or `float64`. Float type for pytensor (and jax if `jax=True`).
jax : optional
update_jax : optional
Whether this function also sets float type for JAX by changing the
`jax_enable_x64` setting in JAX config. Defaults to True.
"""
Expand All @@ -278,9 +278,9 @@ def set_floatX(dtype: Literal["float32", "float64"], jax: bool = True):
pytensor.config.floatX = dtype
_logger.info("Setting PyTensor floatX type to %s.", dtype)

if jax:
if update_jax:
jax_enable_x64 = dtype == "float64"
config.update("jax_enable_x64", jax_enable_x64)
jax.config.update("jax_enable_x64", jax_enable_x64)

_logger.info(
'Setting "jax_enable_x64" to %s. '
Expand Down
2 changes: 1 addition & 1 deletion tests/slow/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from hssm.utils import _rearrange_data

hssm.set_floatX("float32", jax=True)
hssm.set_floatX("float32", update_jax=True)

# AF-TODO: Include more tests that use different link functions!

Expand Down
2 changes: 1 addition & 1 deletion tests/test_hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hssm.utils import download_hf
from hssm.likelihoods import DDM, logp_ddm

hssm.set_floatX("float32", jax=True)
hssm.set_floatX("float32", update_jax=True)

param_v = {
"name": "v",
Expand Down

0 comments on commit a135502

Please sign in to comment.