From 6a4206221fb6cc840ae143f31caa02f47e1f7c1d Mon Sep 17 00:00:00 2001 From: ChexDev Date: Thu, 12 Oct 2023 14:57:41 -0700 Subject: [PATCH] Fix warning of the form DeprecationWarning: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any] PiperOrigin-RevId: 573020724 --- chex/_src/pytypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index 1b2b4ea..8265f44 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,7 +14,7 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Iterable, Mapping, Union +from typing import Any, Iterable, Mapping, Union, Sequence import jax import numpy as np @@ -52,7 +52,7 @@ # Other types. Scalar = Union[float, int] Numeric = Union[Array, Scalar] -Shape = jax.core.Shape +Shape = Sequence[int | Any] PRNGKey = jax.Array PyTreeDef = jax.tree_util.PyTreeDef Device = jax.Device