diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 893814f0f..6c0711142 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -3,7 +3,7 @@ from collections import OrderedDict from functools import partial -from typing import Callable +from typing import Callable, Optional import jax from jax import device_put, lax, random @@ -348,7 +348,7 @@ def scan( f: Callable, init, xs, - length: int | None = None, + length: Optional[int] = None, reverse: bool = False, history: int = 1, ): diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index 912ed46f2..02a68acf3 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict, namedtuple -from typing import Any, Callable, OrderedDict as OrderedDictType +from typing import Any, Callable, OrderedDict as OrderedDictType, Union import jax from jax import random @@ -21,9 +21,9 @@ SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"]) -RunInferenceResult = ( - dict[str, Any] | tuple[AutoGuide, dict[str, Any]] -) # for mcmc or sdvi +RunInferenceResult = Union[ + dict[str, Any], tuple[AutoGuide, dict[str, Any]] +] # for mcmc or sdvi class StochasticSupportInference(ABC): @@ -124,12 +124,12 @@ def _combine_inferences( branching_traces: dict[str, OrderedDictType], *args: Any, **kwargs: Any, - ) -> DCCResult | SDVIResult: + ) -> Union[DCCResult, SDVIResult]: raise NotImplementedError def run( self, rng_key: ArrayLike, *args: Any, **kwargs: Any - ) -> DCCResult | SDVIResult: + ) -> Union[DCCResult, SDVIResult]: """ Run inference on each SLP separately and combine the results. diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py index 527bd9b24..a68631b6a 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -7,6 +7,7 @@ from collections import OrderedDict from itertools import product +from typing import Union import numpy as np @@ -230,7 +231,7 @@ def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray: def summary( - samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True + samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True ) -> dict: """ Returns a summary table displaying diagnostics of ``samples`` from the @@ -284,7 +285,7 @@ def summary( def print_summary( - samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True + samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True ) -> None: """ Prints a summary table displaying diagnostics of ``samples`` from the diff --git a/numpyro/util.py b/numpyro/util.py index 828eb1ded..3c710f41a 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -10,7 +10,7 @@ import random import re from threading import Lock -from typing import Any, Callable, Generator +from typing import Any, Callable, Generator, Optional import warnings import numpy as np @@ -27,7 +27,7 @@ _CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3' -def set_rng_seed(rng_seed: int | None = None) -> None: +def set_rng_seed(rng_seed: Optional[int] = None) -> None: """ Initializes internal state for the Python and NumPy random number generators. @@ -49,7 +49,7 @@ def enable_x64(use_x64: bool = True) -> None: jax.config.update("jax_enable_x64", use_x64) -def set_platform(platform: str | None = None) -> None: +def set_platform(platform: Optional[str] = None) -> None: """ Changes platform to CPU, GPU, or TPU. This utility only takes effect at the beginning of your program. @@ -408,7 +408,7 @@ def loop_fn(collection): def soft_vmap( - fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None + fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: Optional[int] = None ) -> Any: """ Vectorizing map that maps a function `fn` over `batch_ndims` leading axes @@ -466,7 +466,7 @@ def format_shapes( *, compute_log_prob: bool = False, title: str = "Trace Shapes:", - last_site: str | None = None, + last_site: Optional[str] = None, ): """ Given the trace of a function, returns a string showing a table of the shapes of