Skip to content

Commit

Permalink
Bump version (#63)
Browse files Browse the repository at this point in the history
* Bump version

* Switched several stateful errors to RuntimeErrors.

This is for compatibility with jaxlib>=0.3.5, which forcibly converts errors raised inside host_callback.call into RuntimeErrors.

* Bumped minimum JAX version for simplicity.
  • Loading branch information
patrick-kidger authored Apr 8, 2022
1 parent 0c624e8 commit f3b55d5
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ _(In other words, why should you care? Because Equinox is really simple to learn
pip install equinox
```

Requires Python 3.7+ and JAX 0.2.18+.
Requires Python 3.7+ and JAX 0.3.4+.

## Documentation

Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ _(In other words, why should you care? Because Equinox is really simple to learn
pip install equinox
```

Requires Python 3.7+ and JAX 0.2.18+.
Requires Python 3.7+ and JAX 0.3.4+.

## Quick example

Expand Down
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
from .update import apply_updates


__version__ = "0.3.2"
__version__ = "0.4.0"
41 changes: 17 additions & 24 deletions equinox/experimental/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
import jax
import jax.experimental.host_callback as hcb
import jax.interpreters.batching as batching


try:
import jax.interpreters.mlir as mlir
except ImportError:
mlir = None
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.lax as lax
import jax.numpy as jnp
Expand Down Expand Up @@ -326,7 +321,7 @@ def _outside_call_batching_rule(

def _batchify_impl(*flat, treedef, like_batch_axes, current_batch_axes):
if current_batch_axes != like_batch_axes:
raise TypeError("`like` and the saved state have different batch axes")
raise RuntimeError("`like` and the saved state have different batch axes")
state, like = jax.tree_unflatten(treedef, flat)
return jax.tree_leaves(state)

Expand Down Expand Up @@ -370,12 +365,9 @@ def _batchify_batching_rule(
_batchify_p,
xla.lower_fun(_batchify_impl, multiple_results=True, new_style=True),
)
# The `mlir` module got added in later JAX versions.
# (Probably just `if mlir is not None` would suffice here?)
if hasattr(mlir, "lower_fun") and hasattr(mlir, "register_lowering"):
mlir.register_lowering(
_batchify_p, mlir.lower_fun(_batchify_impl, multiple_results=True)
)
mlir.register_lowering(
_batchify_p, mlir.lower_fun(_batchify_impl, multiple_results=True)
)


class _GetStateArg(Module):
Expand All @@ -390,9 +382,9 @@ def _get_state_hcb(arg: _GetStateArg) -> PyTree:
try:
current_state, current_batch_axes, _ = _state_cache[index._obj]
except KeyError as e:
raise KeyError("Cannot get state before it has been set") from e
raise RuntimeError("Cannot get state before it has been set") from e
if current_batch_axes != batch_axes:
raise TypeError("`like` and the saved state have different batch axes")
raise RuntimeError("`like` and the saved state have different batch axes")
return current_state


Expand Down Expand Up @@ -423,10 +415,11 @@ def get_state(index: StateIndex, like: PyTree[Array]) -> PyTree[Array]:
A `TypeError` at trace time if `like` is not a PyTree of JAX arrays.
A `TypeError` at run time if `like` is not of the same shape, dtype, PyTree
A `RuntimeError` at run time if `like` is not of the same shape, dtype, PyTree
structure, and batch axes as the retrieved value.
A `KeyError` at run time if no state has previously been saved with this `index`.
A `RuntimeError` at run time if no state has previously been saved with this
`index`.
!!! warning
Expand Down Expand Up @@ -483,14 +476,14 @@ def get_state(index: StateIndex, like: PyTree[Array]) -> PyTree[Array]:
index._obj
]
except KeyError as e:
raise KeyError("Cannot get state before it has been set") from e
raise RuntimeError("Cannot get state before it has been set") from e
if current_version == index._version.value:
state = index._state
else:
state = jax.tree_map(jnp.asarray, current_state)
_treedef = jax.tree_structure(state)
if _treedef != jax.tree_structure(state):
raise ValueError(
raise RuntimeError(
"`like` has different PyTree structure to the stored state"
)
flat, treedef = jax.tree_flatten((state, like))
Expand Down Expand Up @@ -530,11 +523,11 @@ def _set_state_hcb(arg: _SetStateArg) -> None:
current_state_shape = jax.eval_shape(lambda: current_state)
state_shape = jax.eval_shape(lambda: state)
if current_state_shape != state_shape:
raise TypeError(
raise RuntimeError(
"New state and old state have different shape, dtype, or PyTree structure"
)
if current_batch_axes != batch_axes:
raise TypeError("New state and old state have different batch axes")
raise RuntimeError("New state and old state have different batch axes")
_state_cache[index._obj] = (state, batch_axes, current_version + 1)


Expand All @@ -553,13 +546,13 @@ def set_state(index: StateIndex, state: PyTree[Array]) -> None:
**Raises:**
A `RuntimeError` at trace time if `index.inference` is truthy.
A `TypeError` at trace time if `state` is not a PyTree of JAX arrays.
A `TypeError` at run time if this `index` has previously been used to save a
A `RuntimeError` at run time if this `index` has previously been used to save a
`state` with a different shape, dtype, PyTree structure, or batch axes.
A `RuntimeError` at trace time if `index.inference` is truthy.
!!! info
The same `index` can be used multiple times, to overwrite a previously saved
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

python_requires = "~=3.7"

install_requires = ["jax>=0.2.26", "jaxlib>=0.1.76"]
install_requires = ["jax>=0.3.4"]

setuptools.setup(
name=name,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def test_batch_norm(getkey):

# Test that switching to a different amount of batching raises an error

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
jax.vmap(bn, axis_name="batch")(x1)

# Test that it normalises
Expand Down
30 changes: 15 additions & 15 deletions tests/test_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_no_nonjaxarray():
def test_no_set():
index = eqx.experimental.StateIndex()
a = jnp.array(2)
with pytest.raises(KeyError):
with pytest.raises(RuntimeError):
eqx.experimental.get_state(index, a)


Expand All @@ -70,9 +70,9 @@ def set_state2():
eqx.experimental.set_state(index2, jnp.array(1))
eqx.experimental.set_state(index2, [jnp.array(1)])

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
set_state1()
with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
set_state2()


Expand Down Expand Up @@ -123,21 +123,21 @@ def vmap_get_state(i, x):
vmap_set_state(index1, set_)
assert jnp.array_equal(vmap_get_state(index1, get_), set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# setting state without vmap, after setting state with vmap
eqx.experimental.set_state(index1, set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# getting state without vmap, after setting state with vmap
eqx.experimental.get_state(index1, get_)

eqx.experimental.set_state(index2, set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# setting state with vmap, after setting state without vmap
vmap_set_state(index2, set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# getting state with vmap, after setting state without vmap
vmap_get_state(index2, get_)

Expand Down Expand Up @@ -178,10 +178,10 @@ def get_state_bad(y):
set_state(set_)
assert jnp.array_equal(get_state(get_), set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
eqx.experimental.get_state(index, get_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
get_state_bad(get_)


Expand All @@ -193,7 +193,7 @@ def test_inference_not_set_state():

def test_inference_no_state():
index = eqx.experimental.StateIndex(inference=True)
with pytest.raises(KeyError):
with pytest.raises(RuntimeError):
eqx.experimental.get_state(index, jnp.array(1))


Expand All @@ -205,7 +205,7 @@ def test_inference_not_set_under_jit():
def f(i):
eqx.tree_at(lambda j: j.inference, i, True)

with pytest.raises(ValueError):
with pytest.raises(RuntimeError):
f(index)


Expand Down Expand Up @@ -377,13 +377,13 @@ def vmap_get_state(i, x):
vmap_set_state(index1, set_)
assert jnp.array_equal(vmap_get_state(index1_inference, get_), set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# getting state without vmap, after setting state with vmap
eqx.experimental.get_state(index1_inference, get_)

eqx.experimental.set_state(index2, set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
# getting state with vmap, after setting state without vmap
vmap_get_state(index2_inference, get_)

Expand Down Expand Up @@ -425,10 +425,10 @@ def get_state_bad(y):
set_state(set_)
assert jnp.array_equal(get_state(get_), set_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
eqx.experimental.get_state(index_inference, get_)

with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
get_state_bad(get_)


Expand Down

0 comments on commit f3b55d5

Please sign in to comment.