Skip to content

Commit

Permalink
chex: improve definition of chex.ArrayDType
Browse files Browse the repository at this point in the history
Previously, `chex.ArrayDType` has resolved to `Any`. This makes it equivalent to `jax.typing.DTypeLike`, added in jax v0.4.19 (see jax-ml/jax#18042).

PiperOrigin-RevId: 572619589
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Oct 11, 2023
1 parent 381fc8e commit c69eeb3
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions chex/_src/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any, Iterable, Mapping, Union

import jax
import jax.numpy as jnp
import numpy as np

# Special types of arrays.
Expand All @@ -33,9 +32,12 @@
# Generic array type.
# Similar to `jax.typing.ArrayLike` but does not accept python scalar types.
Array = Union[
ArrayDevice, ArrayBatched, ArraySharded, # JAX array type
ArrayDevice,
ArrayBatched,
ArraySharded, # JAX array type
ArrayNumpy, # NumPy array type
np.bool_, np.number, # NumPy scalar types
np.bool_,
np.number, # NumPy scalar types
]

# A tree of generic arrays.
Expand All @@ -54,4 +56,10 @@
PRNGKey = jax.Array
PyTreeDef = jax.tree_util.PyTreeDef
Device = jax.Device
ArrayDType = type(jnp.float32)

# TODO(iukemaev, jakevdp): upgrade minimum jax version & remove this condition.
if hasattr(jax.typing, 'DTypeLike'):
# jax version 0.4.19 or newer
ArrayDType = jax.typing.DTypeLike
else:
ArrayDType = Any

0 comments on commit c69eeb3

Please sign in to comment.