Skip to content

Commit

Permalink
types
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 17, 2024
1 parent 66921bd commit 00d1994
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 48 deletions.
24 changes: 14 additions & 10 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
]


def _compute_chain_variance_stats(x):
def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
Expand All @@ -40,7 +40,7 @@ def _compute_chain_variance_stats(x):
return var_within, var_estimator


def gelman_rubin(x):
def gelman_rubin(x: np.ndarray) -> np.ndarray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -59,7 +59,7 @@ def gelman_rubin(x):
return rhat


def split_gelman_rubin(x):
def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -78,7 +78,7 @@ def split_gelman_rubin(x):
return split_rhat


def _fft_next_fast_len(target):
def _fft_next_fast_len(target: int) -> int:
# find the smallest number >= N such that the only divisors are 2, 3, 5
# works just like scipy.fftpack.next_fast_len
if target <= 2:
Expand All @@ -96,7 +96,7 @@ def _fft_next_fast_len(target):
target += 1


def autocorrelation(x, axis=0, bias=True):
def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
"""
Computes the autocorrelation of samples at dimension ``axis``.
Expand Down Expand Up @@ -140,7 +140,7 @@ def autocorrelation(x, axis=0, bias=True):
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x, axis=0, bias=True):
def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
"""
Computes the autocovariance of samples at dimension ``axis``.
Expand All @@ -153,7 +153,7 @@ def autocovariance(x, axis=0, bias=True):
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x, bias=True):
def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -201,7 +201,7 @@ def effective_sample_size(x, bias=True):
return n_eff


def hpdi(x, prob=0.90, axis=0):
def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
Expand Down Expand Up @@ -229,7 +229,9 @@ def hpdi(x, prob=0.90, axis=0):
return np.concatenate([hpd_left, hpd_right], axis=axis)


def summary(samples, prob=0.90, group_by_chain=True):
def summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
) -> dict:
"""
Returns a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Expand Down Expand Up @@ -279,7 +281,9 @@ def summary(samples, prob=0.90, group_by_chain=True):
return summary_dict


def print_summary(samples, prob=0.90, group_by_chain=True):
def print_summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
posterior. The diagnostics displayed are mean, standard deviation, median,
Expand Down
4 changes: 3 additions & 1 deletion numpyro/patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from types import ModuleType
from typing import Callable


def patch_dependency(target, root_module):
def patch_dependency(target: str, root_module: ModuleType) -> Callable:
parts = target.split(".")
assert parts[0] == root_module.__name__
module = root_module
Expand Down
74 changes: 39 additions & 35 deletions numpyro/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
Expand All @@ -10,6 +9,7 @@
import random
import re
from threading import Lock
from typing import Any, Callable, Generator
import warnings

import numpy as np
Expand All @@ -26,7 +26,7 @@
_CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3'


def set_rng_seed(rng_seed):
def set_rng_seed(rng_seed: int | None = None) -> None:
"""
Initializes internal state for the Python and NumPy random number generators.
Expand All @@ -36,19 +36,19 @@ def set_rng_seed(rng_seed):
np.random.seed(rng_seed)


def enable_x64(use_x64=True):
def enable_x64(use_x64: bool = True) -> None:
"""
Changes the default array type to use 64 bit precision as in NumPy.
:param bool use_x64: when `True`, JAX arrays will use 64 bits by default;
else 32 bits.
"""
if not use_x64:
use_x64 = os.getenv("JAX_ENABLE_X64", 0)
jax.config.update("jax_enable_x64", bool(use_x64))
use_x64 = bool(os.getenv("JAX_ENABLE_X64", 0))
jax.config.update("jax_enable_x64", use_x64)


def set_platform(platform=None):
def set_platform(platform: str | None = None) -> None:
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand All @@ -60,7 +60,7 @@ def set_platform(platform=None):
jax.config.update("jax_platform_name", platform)


def set_host_device_count(n):
def set_host_device_count(n: int) -> None:
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA
that there are `n` host (CPU) devices available to use. As a consequence, this
Expand All @@ -79,17 +79,17 @@ def set_host_device_count(n):
:param int n: number of CPU devices to use.
"""
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags_str = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
r"--xla_force_host_platform_device_count=\S+", "", xla_flags_str
).split()
os.environ["XLA_FLAGS"] = " ".join(
["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
)


@contextmanager
def optional(condition, context_manager):
def optional(condition: bool, context_manager) -> Generator:
"""
Optionally wrap inside `context_manager` if condition is `True`.
"""
Expand All @@ -101,7 +101,7 @@ def optional(condition, context_manager):


@contextmanager
def control_flow_prims_disabled():
def control_flow_prims_disabled() -> Generator:
global _DISABLE_CONTROL_FLOW_PRIM
stored_flag = _DISABLE_CONTROL_FLOW_PRIM
try:
Expand All @@ -111,14 +111,16 @@ def control_flow_prims_disabled():
_DISABLE_CONTROL_FLOW_PRIM = stored_flag


def maybe_jit(fn, *args, **kwargs):
def maybe_jit(fn: Callable, *args, **kwargs) -> Callable:
if _DISABLE_CONTROL_FLOW_PRIM:
return fn
else:
return jit(fn, *args, **kwargs)


def cond(pred, true_operand, true_fun, false_operand, false_fun):
def cond(
pred: bool, true_operand, true_fun: Callable, false_operand, false_fun: Callable
) -> Any:
if _DISABLE_CONTROL_FLOW_PRIM:
if pred:
return true_fun(true_operand)
Expand All @@ -128,7 +130,7 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)


def while_loop(cond_fun, body_fun, init_val):
def while_loop(cond_fun: Callable, body_fun: Callable, init_val: Any) -> Any:
if _DISABLE_CONTROL_FLOW_PRIM:
val = init_val
while cond_fun(val):
Expand All @@ -148,7 +150,7 @@ def fori_loop(lower, upper, body_fun, init_val):
return lax.fori_loop(lower, upper, body_fun, init_val)


def is_prng_key(key):
def is_prng_key(key: jax.Array) -> bool:
try:
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return key.shape == ()
Expand Down Expand Up @@ -190,7 +192,7 @@ def _wrapped(fn):
return _wrapped


def progress_bar_factory(num_samples, num_chains):
def progress_bar_factory(num_samples: int, num_chains: int) -> Callable:
"""Factory that builds a progress bar decorator along
with the `set_tqdm_description` and `close_tqdm` functions
"""
Expand Down Expand Up @@ -254,7 +256,7 @@ def _update_progress_bar(iter_num, chain):
)
return chain

def progress_bar_fori_loop(func):
def progress_bar_fori_loop(func: Callable) -> Callable:
"""Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
Note that `body_fun` must be looping over a tuple who's first element is `np.arange(num_samples)`.
This means that `iter_num` is the current iteration number
Expand All @@ -272,13 +274,13 @@ def wrapper_progress_bar(i, vals):


def fori_collect(
lower,
upper,
body_fun,
init_val,
transform=identity,
progbar=True,
return_last_val=False,
lower: int,
upper: int,
body_fun: Callable,
init_val: Any,
transform: Callable = identity,
progbar: bool = True,
return_last_val: bool = False,
collection_size=None,
thinning=1,
**progbar_opts,
Expand Down Expand Up @@ -404,7 +406,9 @@ def loop_fn(collection):
return (collection, last_val) if return_last_val else collection


def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
def soft_vmap(
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None
) -> Any:
"""
Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
Expand Down Expand Up @@ -457,11 +461,11 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):


def format_shapes(
trace,
trace: dict,
*,
compute_log_prob=False,
title="Trace Shapes:",
last_site=None,
compute_log_prob: bool = False,
title: str = "Trace Shapes:",
last_site: str | None = None,
):
"""
Given the trace of a function, returns a string showing a table of the shapes of
Expand Down Expand Up @@ -510,7 +514,7 @@ def model(*args, **kwargs):
batch_shape = getattr(site["fn"], "batch_shape", ())
event_shape = getattr(site["fn"], "event_shape", ())
rows.append(
[f"{name} dist", None]
[f"{name} dist", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
+ [str(size) for size in event_shape]
Expand All @@ -522,7 +526,7 @@ def model(*args, **kwargs):
batch_shape = shape[: len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim :]
rows.append(
["value", None]
["value", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
+ [str(size) for size in event_shape]
Expand All @@ -534,14 +538,14 @@ def model(*args, **kwargs):
):
batch_shape = getattr(site["fn"].log_prob(site["value"]), "shape", ())
rows.append(
["log_prob", None]
["log_prob", None] # type: ignore[arg-type]
+ [str(size) for size in batch_shape]
+ ["|", None]
)
elif site["type"] == "plate":
shape = getattr(site["value"], "shape", ())
rows.append(
[f"{name} plate", None] + [str(size) for size in shape] + ["|", None]
[f"{name} plate", None] + [str(size) for size in shape] + ["|", None] # type: ignore[arg-type]
)

if name == last_site:
Expand All @@ -551,7 +555,7 @@ def model(*args, **kwargs):


# TODO: follow pyro.util.check_site_shape logics for more complete validation
def _validate_model(model_trace, plate_warning="loose"):
def _validate_model(model_trace: dict, plate_warning: str = "loose") -> None:
# TODO: Consider exposing global configuration for those strategies.
assert plate_warning in ["loose", "strict", "error"]
enum_dims = set(
Expand Down Expand Up @@ -591,7 +595,7 @@ def _validate_model(model_trace, plate_warning="loose"):
warnings.warn(message, stacklevel=find_stack_level())


def check_model_guide_match(model_trace, guide_trace):
def check_model_guide_match(model_trace: dict, guide_trace: dict) -> None:
"""
Checks the following assumptions:
Expand Down
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,16 @@ doctest_optionflags = [
[tool.mypy]
ignore_errors = true
ignore_missing_imports = true
plugins = ["numpy.typing.mypy_plugin"]

[[tool.mypy.overrides]]
module = [
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
"numpyro.contrib.stochastic_support.*",
"numpyro.diagnostics.*",
"numpyro.patch.*",
"numpyro.util.*",
]
ignore_errors = false

0 comments on commit 00d1994

Please sign in to comment.