diff --git a/docs/requirements.txt b/docs/requirements.txt index 31b8549548..03a7fa7546 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ funsor ipython jax jaxlib -jaxns==2.4.8 +jaxns>=2.6.2 Jinja2 matplotlib multipledispatch diff --git a/setup.py b/setup.py index b47b0bc92b..89d78ce44e 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "flax", "funsor>=0.4.1", "graphviz", - "jaxns==2.4.8", + "jaxns>=2.6.2", "matplotlib", "optax>=0.0.6", "pylab-sdk", # jaxns dependency