diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py
index a3eae2af0..1ae30f469 100644
--- a/numpyro/diagnostics.py
+++ b/numpyro/diagnostics.py
@@ -24,7 +24,7 @@
 ]


-def _compute_chain_variance_stats(x):
+def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     # compute within-chain variance and variance estimator
     # input has shape C x N x sample_shape
     C, N = x.shape[:2]
@@ -40,7 +40,7 @@ def _compute_chain_variance_stats(x):
     return var_within, var_estimator


-def gelman_rubin(x):
+def gelman_rubin(x: np.ndarray) -> np.ndarray:
     """
     Computes R-hat over chains of samples ``x``, where the first dimension of
     ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -59,7 +59,7 @@ def gelman_rubin(x):
     return rhat


-def split_gelman_rubin(x):
+def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
     """
     Computes split R-hat over chains of samples ``x``, where the first dimension
     of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -78,7 +78,7 @@ def split_gelman_rubin(x):
     return split_rhat


-def _fft_next_fast_len(target):
+def _fft_next_fast_len(target: int) -> int:
     # find the smallest number >= N such that the only divisors are 2, 3, 5
     # works just like scipy.fftpack.next_fast_len
     if target <= 2:
@@ -96,7 +96,7 @@ def _fft_next_fast_len(target):
         target += 1


-def autocorrelation(x, axis=0, bias=True):
+def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
     """
     Computes the autocorrelation of samples at dimension ``axis``.

@@ -140,7 +140,7 @@ def autocorrelation(x, axis=0, bias=True):
     return np.swapaxes(autocorr, axis, -1)


-def autocovariance(x, axis=0, bias=True):
+def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
     """
     Computes the autocovariance of samples at dimension ``axis``.

@@ -153,7 +153,7 @@
     return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


-def effective_sample_size(x, bias=True):
+def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
     """
     Computes effective sample size of input ``x``, where the first dimension of
     ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -201,7 +201,7 @@ def effective_sample_size(x, bias=True):
     return n_eff


-def hpdi(x, prob=0.90, axis=0):
+def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
     """
     Computes "highest posterior density interval" (HPDI) which is the narrowest
     interval with probability mass ``prob``.
@@ -229,7 +229,9 @@ def hpdi(x, prob=0.90, axis=0):
     return np.concatenate([hpd_left, hpd_right], axis=axis)


-def summary(samples, prob=0.90, group_by_chain=True):
+def summary(
+    samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
+) -> dict:
     """
     Returns a summary table displaying diagnostics of ``samples`` from the
     posterior. The diagnostics displayed are mean, standard deviation, median,
@@ -279,7 +281,9 @@
     return summary_dict


-def print_summary(samples, prob=0.90, group_by_chain=True):
+def print_summary(
+    samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
+) -> None:
     """
     Prints a summary table displaying diagnostics of ``samples`` from the posterior.
     The diagnostics displayed are mean, standard deviation, median,
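Note on the diagnostics annotations: the array-in/array-out contracts match the docstrings (chain on the first axis, draw on the second). A minimal usage sketch, not part of the patch, exercising the annotated API:

```python
# Illustrative sketch, not part of the patch.
import numpy as np

from numpyro.diagnostics import gelman_rubin, hpdi, print_summary

samples = np.random.normal(size=(4, 1000))  # 4 chains x 1000 draws
rhat = gelman_rubin(samples)  # np.ndarray: R-hat over the chain/draw axes
interval = hpdi(samples.reshape(-1), prob=0.90)  # np.ndarray: [lower, upper]
print_summary({"x": samples})  # returns None, matching the new annotation
```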
diff --git a/numpyro/patch.py b/numpyro/patch.py
index 20b3d8286..b4d22005c 100644
--- a/numpyro/patch.py
+++ b/numpyro/patch.py
@@ -1,8 +1,11 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

+from types import ModuleType
+from typing import Callable

-def patch_dependency(target, root_module):
+
+def patch_dependency(target: str, root_module: ModuleType) -> Callable:
     parts = target.split(".")
     assert parts[0] == root_module.__name__
     module = root_module
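The `Callable` return type reflects that `patch_dependency` is a decorator factory. For reviewers, a simplified sketch of that pattern (illustrative only; `_patch` and `new_fn` are made-up names, not the actual implementation):

```python
# Illustrative sketch, not part of the patch.
from types import ModuleType
from typing import Callable


def _patch(target: str, root_module: ModuleType) -> Callable:
    # Returns a decorator that installs the decorated function
    # at the dotted path `target` inside `root_module`.
    def decorator(new_fn: Callable) -> Callable:
        *path, name = target.split(".")[1:]  # drop the root module name
        obj = root_module
        for part in path:
            obj = getattr(obj, part)  # walk down submodules/attributes
        setattr(obj, name, new_fn)  # install the replacement
        return new_fn

    return decorator
```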
""" @@ -101,7 +102,7 @@ def optional(condition, context_manager): @contextmanager -def control_flow_prims_disabled(): +def control_flow_prims_disabled() -> Generator: global _DISABLE_CONTROL_FLOW_PRIM stored_flag = _DISABLE_CONTROL_FLOW_PRIM try: @@ -111,14 +112,16 @@ def control_flow_prims_disabled(): _DISABLE_CONTROL_FLOW_PRIM = stored_flag -def maybe_jit(fn, *args, **kwargs): +def maybe_jit(fn: Callable, *args, **kwargs) -> Callable: if _DISABLE_CONTROL_FLOW_PRIM: return fn else: return jit(fn, *args, **kwargs) -def cond(pred, true_operand, true_fun, false_operand, false_fun): +def cond( + pred: bool, true_operand, true_fun: Callable, false_operand, false_fun: Callable +) -> Any: if _DISABLE_CONTROL_FLOW_PRIM: if pred: return true_fun(true_operand) @@ -128,7 +131,7 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun): return lax.cond(pred, true_operand, true_fun, false_operand, false_fun) -def while_loop(cond_fun, body_fun, init_val): +def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any: if _DISABLE_CONTROL_FLOW_PRIM: val = init_val while cond_fun(val): @@ -148,7 +151,7 @@ def fori_loop(lower, upper, body_fun, init_val): return lax.fori_loop(lower, upper, body_fun, init_val) -def is_prng_key(key): +def is_prng_key(key: jax.Array) -> bool: try: if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): return key.shape == () @@ -190,7 +193,7 @@ def _wrapped(fn): return _wrapped -def progress_bar_factory(num_samples, num_chains): +def progress_bar_factory(num_samples: int, num_chains: int) -> Callable: """Factory that builds a progress bar decorator along with the `set_tqdm_description` and `close_tqdm` functions """ @@ -254,7 +257,7 @@ def _update_progress_bar(iter_num, chain): ) return chain - def progress_bar_fori_loop(func): + def progress_bar_fori_loop(func: Callable) -> Callable: """Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`. Note that `body_fun` must be looping over a tuple who's first element is `np.arange(num_samples)`. This means that `iter_num` is the current iteration number @@ -272,13 +275,13 @@ def wrapper_progress_bar(i, vals): def fori_collect( - lower, - upper, - body_fun, - init_val, - transform=identity, - progbar=True, - return_last_val=False, + lower: int, + upper: int, + body_fun: Callable, + init_val: Any, + transform: Callable = identity, + progbar: bool = True, + return_last_val: bool = False, collection_size=None, thinning=1, **progbar_opts, @@ -404,7 +407,9 @@ def loop_fn(collection): return (collection, last_val) if return_last_val else collection -def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): +def soft_vmap( + fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None +) -> Any: """ Vectorizing map that maps a function `fn` over `batch_ndims` leading axes of `xs`. 
@@ -457,11 +462,11 @@


 def format_shapes(
-    trace,
+    trace: dict,
     *,
-    compute_log_prob=False,
-    title="Trace Shapes:",
-    last_site=None,
+    compute_log_prob: bool = False,
+    title: str = "Trace Shapes:",
+    last_site: str | None = None,
 ):
     """
     Given the trace of a function, returns a string showing a table of the shapes of
@@ -510,7 +515,7 @@ def model(*args, **kwargs):
             batch_shape = getattr(site["fn"], "batch_shape", ())
             event_shape = getattr(site["fn"], "event_shape", ())
             rows.append(
-                [f"{name} dist", None]
+                [f"{name} dist", None]  # type: ignore[arg-type]
                 + [str(size) for size in batch_shape]
                 + ["|", None]
                 + [str(size) for size in event_shape]
@@ -522,7 +527,7 @@ def model(*args, **kwargs):
             batch_shape = shape[: len(shape) - event_dim]
             event_shape = shape[len(shape) - event_dim :]
             rows.append(
-                ["value", None]
+                ["value", None]  # type: ignore[arg-type]
                 + [str(size) for size in batch_shape]
                 + ["|", None]
                 + [str(size) for size in event_shape]
@@ -534,14 +539,14 @@ def model(*args, **kwargs):
             ):
                 batch_shape = getattr(site["fn"].log_prob(site["value"]), "shape", ())
                 rows.append(
-                    ["log_prob", None]
+                    ["log_prob", None]  # type: ignore[arg-type]
                     + [str(size) for size in batch_shape]
                     + ["|", None]
                 )
         elif site["type"] == "plate":
             shape = getattr(site["value"], "shape", ())
             rows.append(
-                [f"{name} plate", None] + [str(size) for size in shape] + ["|", None]
+                [f"{name} plate", None] + [str(size) for size in shape] + ["|", None]  # type: ignore[arg-type]
             )

         if name == last_site:
@@ -551,7 +556,7 @@
 # TODO: follow pyro.util.check_site_shape logics for more complete validation


-def _validate_model(model_trace, plate_warning="loose"):
+def _validate_model(model_trace: dict, plate_warning: str = "loose") -> None:
     # TODO: Consider exposing global configuration for those strategies.
     assert plate_warning in ["loose", "strict", "error"]
     enum_dims = set(
@@ -591,7 +596,7 @@
             warnings.warn(message, stacklevel=find_stack_level())


-def check_model_guide_match(model_trace, guide_trace):
+def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
     """
     Checks the following assumptions:

diff --git a/pyproject.toml b/pyproject.toml
index e6859c122..37d24a27c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -106,12 +106,16 @@ doctest_optionflags = [
 [tool.mypy]
 ignore_errors = true
 ignore_missing_imports = true
+plugins = ["numpy.typing.mypy_plugin"]

 [[tool.mypy.overrides]]
 module = [
-    "numpyro.contrib.control_flow.*", # types missing
-    "numpyro.contrib.funsor.*", # types missing
+    "numpyro.contrib.control_flow.*",  # types missing
+    "numpyro.contrib.funsor.*",  # types missing
     "numpyro.contrib.hsgp.*",
     "numpyro.contrib.stochastic_support.*",
+    "numpyro.diagnostics.*",
+    "numpyro.patch.*",
+    "numpyro.util.*",
 ]
 ignore_errors = false
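With these overrides, `numpyro.diagnostics`, `numpyro.patch`, and `numpyro.util` are opted into type checking (e.g. `mypy numpyro/util.py`); the `plugins` line assumes the pinned NumPy release still ships `numpy.typing.mypy_plugin`. A sketch (illustrative only, not part of the patch) of the dtype-parameterized annotations these modules can now use:

```python
# Illustrative sketch, not part of the patch.
import numpy as np
import numpy.typing as npt


def within_chain_variance(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    # one variance per chain, reducing over the draw axis
    return x.var(axis=1)
```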