diff --git a/numpyro/contrib/control_flow/cond.py b/numpyro/contrib/control_flow/cond.py index 1931d490c..27a0dcee8 100644 --- a/numpyro/contrib/control_flow/cond.py +++ b/numpyro/contrib/control_flow/cond.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial +from typing import Any, Callable from jax import device_put, lax @@ -72,7 +73,7 @@ def cond_wrapper( return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand) -def cond(pred, true_fun, false_fun, operand): +def cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) -> Any: """ This primitive conditionally applies ``true_fun`` or ``false_fun``. See :func:`jax.lax.cond` for more information. diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 7a2689cf1..4905ba24b 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -3,6 +3,7 @@ from collections import OrderedDict from functools import partial +from typing import Callable import jax from jax import device_put, lax, random @@ -278,7 +279,7 @@ def scan_wrapper( length, reverse, rng_key=None, - substitute_stack=[], + substitute_stack=None, enum=False, history=1, first_available_dim=None, @@ -339,7 +340,14 @@ def body_fn(wrapped_carry, x): return last_carry, (pytree_trace, ys) -def scan(f, init, xs, length=None, reverse=False, history=1): +def scan( + f: Callable, + init, + xs, + length: int | None = None, + reverse: bool = False, + history: int = 1, +): """ This primitive scans a function over the leading array axes of `xs` while carrying along state. See :func:`jax.lax.scan` for more @@ -433,7 +441,7 @@ def g(*args, **kwargs): :param init: the initial carrying state :param xs: the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays). - :param length: optional value specifying the length of `xs` + :param int | None length: optional value specifying the length of `xs` but can be used when `xs` is an empty pytree (e.g. None) :param bool reverse: optional boolean specifying whether to run the scan iteration forward (the default) or in reverse diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index d0596c1e0..1ff0b5c5a 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import namedtuple +from collections.abc import Sequence from contextlib import contextmanager from functools import partial from typing import Callable, Optional @@ -931,7 +932,7 @@ def __init__( guide: Optional[Callable] = None, params: Optional[dict] = None, num_samples: Optional[int] = None, - return_sites: Optional[list[str]] = None, + return_sites: Optional[Sequence[str]] = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: Optional[int] = None, diff --git a/pyproject.toml b/pyproject.toml index c9d1668f7..847758329 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,8 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - #"numpyro.contrib.einstein.*", + "numpyro.contrib.control_flow.*", # types missing + "numpyro.contrib.funsor.*", # types missing "numpyro.contrib.hsgp.*", ] ignore_errors = false