Skip to content

Commit

Permalink
Bump the minimum Jax version to 0.4.6.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 518893518
  • Loading branch information
hbq1 authored and ChexDev committed Mar 23, 2023
1 parent 20959e8 commit 484e6ad
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 22 deletions.
6 changes: 1 addition & 5 deletions chex/_src/asserts_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ def is_traceable(fn) -> bool:
"""

fn_string_tokens = (
"_python_jit.", # PyJIT in Python ver. < 3.7
"_cpp_jit.", # CppJIT in Python ver. < 3.7 (deprecated)
".reraise_with_filtered_traceback", # JIT in Python ver. >= 3.7
"CompiledFunction", # C++ JIT in jaxlib 0.1.66 or newer.
"pmap.", # Python pmap
Expand All @@ -354,7 +352,6 @@ def is_traceable(fn) -> bool:
)

fn_type_tokens = (
"CompiledFunction",
"PmapFunction",
"PjitFunction",
)
Expand All @@ -381,8 +378,7 @@ def is_traceable(fn) -> bool:
return True

try:
if isinstance(fn_, (jax.lib.xla_extension.jax_jit.CompiledFunction,
jax.lib.xla_extension.PjitFunction)):
if isinstance(fn_, jax.lib.xla_extension.PjitFunction):
return True
except AttributeError:
pass
Expand Down
23 changes: 7 additions & 16 deletions chex/_src/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@

# Special types of arrays.
ArrayNumpy = np.ndarray

# For instance checking, use `isinstance(x, jax.Array)`.
if hasattr(jax, 'Array'):
ArrayDevice = jax.Array # jax >= 0.3.20
ArraySharded = jax.Array
ArrayBatched = jax.Array
elif hasattr(jax.interpreters.xla, '_DeviceArray'): # 0.2.5 < jax < 0.3.20
ArrayDevice = jax.interpreters.xla._DeviceArray # pylint:disable=protected-access
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
ArrayBatched = jax.interpreters.batching.BatchTracer
else: # jax <= 0.2.5
ArrayDevice = jax.interpreters.xla.DeviceArray
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
ArrayBatched = jax.interpreters.batching.BatchTracer
ArrayDevice = jax.Array

# Types for backward compatibility.
ArraySharded = jax.Array
ArrayBatched = jax.Array

# Generic array type.
# Similar to `jax.typing.ArrayLike` but does not accept python scalar types.
Expand All @@ -59,8 +53,5 @@
Shape = jax.core.Shape
PRNGKey = jax.random.KeyArray
PyTreeDef = jax.tree_util.PyTreeDef
if hasattr(jax, 'Device'):
Device = jax.Device # jax >= 0.4.3
else:
Device = jax.lib.xla_extension.Device
Device = jax.Device
ArrayDType = type(jnp.float32)
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
absl-py>=0.9.0
typing_extensions>=4.2.0; python_version<"3.11"
dm-tree>=0.1.5
jax>=0.1.55
jax>=0.4.6
jaxlib>=0.1.37
numpy>=1.18.0
toolz>=0.9.0

0 comments on commit 484e6ad

Please sign in to comment.