Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Types] Diagnostics and Utils #1912

Merged
merged 2 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
]


def _compute_chain_variance_stats(x):
def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these functions assume true NumPy arrays, so annotating them with `np.ndarray` was the only way to make the type checks pass without making many other changes.

# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
Expand All @@ -40,7 +40,7 @@ def _compute_chain_variance_stats(x):
return var_within, var_estimator


def gelman_rubin(x):
def gelman_rubin(x: np.ndarray) -> np.ndarray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -59,7 +59,7 @@ def gelman_rubin(x):
return rhat


def split_gelman_rubin(x):
def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -78,7 +78,7 @@ def split_gelman_rubin(x):
return split_rhat


def _fft_next_fast_len(target):
def _fft_next_fast_len(target: int) -> int:
# find the smallest number >= target whose only prime factors are 2, 3, 5
# works just like scipy.fftpack.next_fast_len
if target <= 2:
Expand All @@ -96,7 +96,7 @@ def _fft_next_fast_len(target):
target += 1


def autocorrelation(x, axis=0, bias=True):
def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
"""
Computes the autocorrelation of samples at dimension ``axis``.

Expand Down Expand Up @@ -140,7 +140,7 @@ def autocorrelation(x, axis=0, bias=True):
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x, axis=0, bias=True):
def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
"""
Computes the autocovariance of samples at dimension ``axis``.

Expand All @@ -153,7 +153,7 @@ def autocovariance(x, axis=0, bias=True):
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x, bias=True):
def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -201,7 +201,7 @@ def effective_sample_size(x, bias=True):
return n_eff


def hpdi(x, prob=0.90, axis=0):
def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
Expand Down Expand Up @@ -229,7 +229,9 @@ def hpdi(x, prob=0.90, axis=0):
return np.concatenate([hpd_left, hpd_right], axis=axis)


def summary(samples, prob=0.90, group_by_chain=True):
def summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
) -> dict:
"""
Returns a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Expand Down Expand Up @@ -279,7 +281,9 @@ def summary(samples, prob=0.90, group_by_chain=True):
return summary_dict


def print_summary(samples, prob=0.90, group_by_chain=True):
def print_summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Expand Down
4 changes: 3 additions & 1 deletion numpyro/patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from types import ModuleType
from typing import Callable


def patch_dependency(target, root_module):
def patch_dependency(target: str, root_module: ModuleType) -> Callable:
parts = target.split(".")
assert parts[0] == root_module.__name__
module = root_module
Expand Down
74 changes: 39 additions & 35 deletions numpyro/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
Expand All @@ -10,6 +9,7 @@
import random
import re
from threading import Lock
from typing import Any, Callable, Generator
import warnings

import numpy as np
Expand All @@ -26,7 +26,7 @@
_CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3'


def set_rng_seed(rng_seed):
def set_rng_seed(rng_seed: int | None = None) -> None:
"""
Initializes internal state for the Python and NumPy random number generators.

Expand All @@ -36,19 +36,19 @@ def set_rng_seed(rng_seed):
np.random.seed(rng_seed)


def enable_x64(use_x64=True):
def enable_x64(use_x64: bool = True) -> None:
"""
Changes the default array type to use 64 bit precision as in NumPy.

:param bool use_x64: when `True`, JAX arrays will use 64 bits by default;
else 32 bits.
"""
if not use_x64:
use_x64 = os.getenv("JAX_ENABLE_X64", 0)
jax.config.update("jax_enable_x64", bool(use_x64))
use_x64 = bool(os.getenv("JAX_ENABLE_X64", 0))
jax.config.update("jax_enable_x64", use_x64)


def set_platform(platform=None):
def set_platform(platform: str | None = None) -> None:
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand All @@ -60,7 +60,7 @@ def set_platform(platform=None):
jax.config.update("jax_platform_name", platform)


def set_host_device_count(n):
def set_host_device_count(n: int) -> None:
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA
that there are `n` host (CPU) devices available to use. As a consequence, this
Expand All @@ -79,17 +79,17 @@ def set_host_device_count(n):

:param int n: number of CPU devices to use.
"""
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags_str = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
r"--xla_force_host_platform_device_count=\S+", "", xla_flags_str
).split()
os.environ["XLA_FLAGS"] = " ".join(
["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
)


@contextmanager
def optional(condition, context_manager):
def optional(condition: bool, context_manager) -> Generator:
"""
Optionally wrap inside `context_manager` if condition is `True`.
"""
Expand All @@ -101,7 +101,7 @@ def optional(condition, context_manager):


@contextmanager
def control_flow_prims_disabled():
def control_flow_prims_disabled() -> Generator:
global _DISABLE_CONTROL_FLOW_PRIM
stored_flag = _DISABLE_CONTROL_FLOW_PRIM
try:
Expand All @@ -111,14 +111,16 @@ def control_flow_prims_disabled():
_DISABLE_CONTROL_FLOW_PRIM = stored_flag


def maybe_jit(fn, *args, **kwargs):
def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
if _DISABLE_CONTROL_FLOW_PRIM:
return fn
else:
return jit(fn, *args, **kwargs)


def cond(pred, true_operand, true_fun, false_operand, false_fun):
def cond(
pred: bool, true_operand, true_fun: Callable, false_operand, false_fun: Callable
) -> Any:
if _DISABLE_CONTROL_FLOW_PRIM:
if pred:
return true_fun(true_operand)
Expand All @@ -128,7 +130,7 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)


def while_loop(cond_fun, body_fun, init_val):
def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
if _DISABLE_CONTROL_FLOW_PRIM:
val = init_val
while cond_fun(val):
Expand All @@ -148,7 +150,7 @@ def fori_loop(lower, upper, body_fun, init_val):
return lax.fori_loop(lower, upper, body_fun, init_val)


def is_prng_key(key):
def is_prng_key(key: jax.Array) -> bool:
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
Expand Down Expand Up @@ -190,7 +192,7 @@ def _wrapped(fn):
return _wrapped


def progress_bar_factory(num_samples, num_chains):
def progress_bar_factory(num_samples: int, num_chains: int) -> Callable:
"""Factory that builds a progress bar decorator along
with the `set_tqdm_description` and `close_tqdm` functions
"""
Expand Down Expand Up @@ -254,7 +256,7 @@ def _update_progress_bar(iter_num, chain):
)
return chain

def progress_bar_fori_loop(func):
def progress_bar_fori_loop(func: Callable) -> Callable:
"""Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
Note that `body_fun` must be looping over a tuple whose first element is `np.arange(num_samples)`.
This means that `iter_num` is the current iteration number
Expand All @@ -272,13 +274,13 @@ def wrapper_progress_bar(i, vals):


def fori_collect(
lower,
upper,
body_fun,
init_val,
transform=identity,
progbar=True,
return_last_val=False,
lower: int,
upper: int,
body_fun: Callable,
init_val: Any,
transform: Callable = identity,
progbar: bool = True,
return_last_val: bool = False,
collection_size=None,
thinning=1,
**progbar_opts,
Expand Down Expand Up @@ -404,7 +406,9 @@ def loop_fn(collection):
return (collection, last_val) if return_last_val else collection


def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
def soft_vmap(
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None
) -> Any:
"""
Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
Expand Down Expand Up @@ -457,11 +461,11 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):


def format_shapes(
trace,
trace: dict,
*,
compute_log_prob=False,
title="Trace Shapes:",
last_site=None,
compute_log_prob: bool = False,
title: str = "Trace Shapes:",
last_site: str | None = None,
):
"""
Given the trace of a function, returns a string showing a table of the shapes of
Expand Down Expand Up @@ -510,7 +514,7 @@ def model(*args, **kwargs):
batch_shape = getattr(site["fn"], "batch_shape", ())
event_shape = getattr(site["fn"], "event_shape", ())
rows.append(
[f"{name} dist", None]
[f"{name} dist", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
+ [str(size) for size in event_shape]
Expand All @@ -522,7 +526,7 @@ def model(*args, **kwargs):
batch_shape = shape[: len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim :]
rows.append(
["value", None]
["value", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
+ [str(size) for size in event_shape]
Expand All @@ -534,14 +538,14 @@ def model(*args, **kwargs):
):
batch_shape = getattr(site["fn"].log_prob(site["value"]), "shape", ())
rows.append(
["log_prob", None]
["log_prob", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
)
elif site["type"] == "plate":
shape = getattr(site["value"], "shape", ())
rows.append(
[f"{name} plate", None] + [str(size) for size in shape] + ["|", None]
[f"{name} plate", None] + [str(size) for size in shape] + ["|", None] # type: ignore[arg-type]
)

if name == last_site:
Expand All @@ -551,7 +555,7 @@ def model(*args, **kwargs):


# TODO: follow the logic of pyro.util.check_site_shape for more complete validation
def _validate_model(model_trace, plate_warning="loose"):
def _validate_model(model_trace: dict, plate_warning: str = "loose") -> None:
# TODO: Consider exposing global configuration for those strategies.
assert plate_warning in ["loose", "strict", "error"]
enum_dims = set(
Expand Down Expand Up @@ -591,7 +595,7 @@ def _validate_model(model_trace, plate_warning="loose"):
warnings.warn(message, stacklevel=find_stack_level())


def check_model_guide_match(model_trace, guide_trace):
def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
"""
Checks the following assumptions:

Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,16 @@ doctest_optionflags = [
[tool.mypy]
ignore_errors = true
ignore_missing_imports = true
plugins = ["numpy.typing.mypy_plugin"]

[[tool.mypy.overrides]]
module = [
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
"numpyro.contrib.stochastic_support.*",
"numpyro.diagnostics.*",
"numpyro.patch.*",
"numpyro.util.*",
]
ignore_errors = false
Loading