Skip to content

Commit

Permalink
use Union
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 20, 2024
1 parent 75abe66 commit 02e149a
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools
from typing import Callable, Generator, Optional, cast
from typing import Callable, Generator, Optional, Union, cast
import warnings

import jax
Expand Down Expand Up @@ -235,7 +235,7 @@ def sample(


def param(
name: str, init_value: Optional[ArrayLike | Callable] = None, **kwargs
name: str, init_value: Optional[Union[ArrayLike, Callable]] = None, **kwargs
) -> Array:
"""
Annotate the given site as an optimizable parameter for use with
Expand Down Expand Up @@ -319,7 +319,9 @@ def deterministic(name: str, value: ArrayLike) -> Array:
return msg["value"]


def mutable(name: str, init_value: Optional[ArrayLike] = None) -> ArrayLike | None:
def mutable(
name: str, init_value: Optional[ArrayLike] = None
) -> Union[ArrayLike, None]:
"""
This primitive is used to store a mutable value that can be changed
during model execution::
Expand Down Expand Up @@ -375,7 +377,7 @@ def _inspect() -> dict:
return msg


def get_mask() -> ArrayLike | None:
def get_mask() -> Union[ArrayLike, None]:
"""
Records the effects of enclosing ``handlers.mask`` handlers.
This is useful for avoiding expensive ``numpyro.factor()`` computations during
Expand Down Expand Up @@ -641,7 +643,7 @@ def factor(name: str, log_factor: ArrayLike) -> None:
sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True})


def prng_key() -> Array | None:
def prng_key() -> Union[Array, None]:
"""
A statement to draw a pseudo-random number generator key
:func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler.
Expand Down

0 comments on commit 02e149a

Please sign in to comment.