diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 0359aaa64..d227a0943 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -4,7 +4,7 @@ from collections import namedtuple from contextlib import ExitStack, contextmanager import functools -from typing import Callable, Generator, Optional, cast +from typing import Callable, Generator, Optional, Union, cast import warnings import jax @@ -235,7 +235,7 @@ def sample( def param( - name: str, init_value: Optional[ArrayLike | Callable] = None, **kwargs + name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, **kwargs ) -> Array: """ Annotate the given site as an optimizable parameter for use with @@ -319,7 +319,9 @@ def deterministic(name: str, value: ArrayLike) -> Array: return msg["value"] -def mutable(name: str, init_value: Optional[ArrayLike] = None) -> ArrayLike | None: +def mutable( + name: str, init_value: Optional[ArrayLike] = None +) -> Union[ArrayLike, None]: """ This primitive is used to store a mutable value that can be changed during model execution:: @@ -375,7 +377,7 @@ def _inspect() -> dict: return msg -def get_mask() -> ArrayLike | None: +def get_mask() -> Union[ArrayLike, None]: """ Records the effects of enclosing ``handlers.mask`` handlers. This is useful for avoiding expensive ``numpyro.factor()`` computations during @@ -641,7 +643,7 @@ def factor(name: str, log_factor: ArrayLike) -> None: sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True}) -def prng_key() -> Array | None: +def prng_key() -> Union[Array, None]: """ A statement to draw a pseudo-random number generator key :func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler.