Skip to content

Commit

Permalink
mypy some modules
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 12, 2024
1 parent ea152d5 commit 643dc21
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Any, Callable

from jax import device_put, lax

Expand Down Expand Up @@ -72,7 +73,7 @@ def cond_wrapper(
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)


def cond(pred, true_fun, false_fun, operand):
def cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) -> Any:
"""
This primitive conditionally applies ``true_fun`` or ``false_fun``. See
:func:`jax.lax.cond` for more information.
Expand Down
14 changes: 11 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -278,7 +279,7 @@ def scan_wrapper(
length,
reverse,
rng_key=None,
substitute_stack=[],
substitute_stack=None,
enum=False,
history=1,
first_available_dim=None,
Expand Down Expand Up @@ -339,7 +340,14 @@ def body_fn(wrapped_carry, x):
return last_carry, (pytree_trace, ys)


def scan(f, init, xs, length=None, reverse=False, history=1):
def scan(
f: Callable,
init,
xs,
length: int | None = None,
reverse: bool = False,
history: int = 1,
):
"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
Expand Down Expand Up @@ -433,7 +441,7 @@ def g(*args, **kwargs):
:param init: the initial carrying state
:param xs: the values over which we scan along the leading axis. This can
be any JAX pytree (e.g. list/dict of arrays).
:param length: optional value specifying the length of `xs`
:param int | None length: optional value specifying the length of `xs`
but can be used when `xs` is an empty pytree (e.g. None)
:param bool reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse
Expand Down
3 changes: 2 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional
Expand Down Expand Up @@ -931,7 +932,7 @@ def __init__(
guide: Optional[Callable] = None,
params: Optional[dict] = None,
num_samples: Optional[int] = None,
return_sites: Optional[list[str]] = None,
return_sites: Optional[Sequence[str]] = None,
infer_discrete: bool = False,
parallel: bool = False,
batch_ndims: Optional[int] = None,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
#"numpyro.contrib.einstein.*",
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
]
ignore_errors = false

0 comments on commit 643dc21

Please sign in to comment.