Skip to content

Commit

Permalink
Proof of Concept: Types and MyPy (#1906)
Browse files Browse the repository at this point in the history
* mypy init got hsgp module

* mypy some modules

* rename step

* fix missing default

* update pre-commit and small typo (caught by codespell)
  • Loading branch information
juanitorduz authored Nov 13, 2024
1 parent a313a6e commit d55d209
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
pip install https://github.com/pyro-ppl/funsor/archive/master.zip
pip install -r docs/requirements.txt
pip freeze
- name: Lint with ruff
- name: Lint with mypy and ruff
run: |
make lint
- name: Build documentation
Expand Down
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ repos:
language: system
files: "(.py$)|(.*.ipynb$)"

- id: mypy
name: mypy
language: python
entry: mypy --install-types --non-interactive
files: ^numpyro/


- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ lint: FORCE
ruff check .
ruff format . --check
python scripts/update_headers.py --check
mypy --install-types --non-interactive numpyro

license: FORCE
python scripts/update_headers.py
Expand Down
3 changes: 2 additions & 1 deletion numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Any, Callable

from jax import device_put, lax

Expand Down Expand Up @@ -72,7 +73,7 @@ def cond_wrapper(
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)


def cond(pred, true_fun, false_fun, operand):
def cond(pred: bool, true_fun: Callable, false_fun: Callable, operand: Any) -> Any:
"""
This primitive conditionally applies ``true_fun`` or ``false_fun``. See
:func:`jax.lax.cond` for more information.
Expand Down
17 changes: 14 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -278,14 +279,17 @@ def scan_wrapper(
length,
reverse,
rng_key=None,
substitute_stack=[],
substitute_stack=None,
enum=False,
history=1,
first_available_dim=None,
):
if length is None:
length = jnp.shape(jax.tree.flatten(xs)[0][0])[0]

if substitute_stack is None:
substitute_stack = []

if enum and history > 0:
return scan_enum( # TODO: replay for enum
f,
Expand Down Expand Up @@ -339,7 +343,14 @@ def body_fn(wrapped_carry, x):
return last_carry, (pytree_trace, ys)


def scan(f, init, xs, length=None, reverse=False, history=1):
def scan(
f: Callable,
init,
xs,
length: int | None = None,
reverse: bool = False,
history: int = 1,
):
"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
Expand Down Expand Up @@ -433,7 +444,7 @@ def g(*args, **kwargs):
:param init: the initial carrying state
:param xs: the values over which we scan along the leading axis. This can
be any JAX pytree (e.g. list/dict of arrays).
:param length: optional value specifying the length of `xs`
:param int | None length: optional value specifying the length of `xs`
but can be used when `xs` is an empty pytree (e.g. None)
:param bool reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse
Expand Down
8 changes: 6 additions & 2 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import numpyro.distributions as dist


def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
def _non_centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)


def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
def _centered_approximation(
phi: ArrayImpl, spd: ArrayImpl, m: int | list[int]
) -> ArrayImpl:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def align_param(dim, param):

def spectral_density_squared_exponential(
dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
) -> ArrayImpl:
"""
Spectral density of the squared exponential kernel.
Expand All @@ -46,7 +46,7 @@ def spectral_density_squared_exponential(
:param float alpha: amplitude
:param float length: length scale
:return: spectral density value
:rtype: float
:rtype: ArrayImpl
"""
length = align_param(dim, length)
c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length, axis=-1)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def __init__(
transition_matrix.ndim == 2
), "`transition_matrix` argument should be a square matrix"
self.transition_matrix = transition_matrix
# Expand the covariance/presicion/scale matrices to the right number of steps.
# Expand the covariance/precision/scale matrices to the right number of steps.
args = {
"covariance_matrix": covariance_matrix,
"precision_matrix": precision_matrix,
Expand Down
3 changes: 2 additions & 1 deletion numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional
Expand Down Expand Up @@ -931,7 +932,7 @@ def __init__(
guide: Optional[Callable] = None,
params: Optional[dict] = None,
num_samples: Optional[int] = None,
return_sites: Optional[list[str]] = None,
return_sites: Optional[Sequence[str]] = None,
infer_discrete: bool = False,
parallel: bool = False,
batch_ndims: Optional[int] = None,
Expand Down
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,15 @@ doctest_optionflags = [
"NORMALIZE_WHITESPACE",
"IGNORE_EXCEPTION_DETAIL",
]

[tool.mypy]
ignore_errors = true
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"numpyro.contrib.control_flow.*", # types missing
"numpyro.contrib.funsor.*", # types missing
"numpyro.contrib.hsgp.*",
]
ignore_errors = false
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"test": [
"importlib-metadata<5.0",
"ruff>=0.1.8",
"mypy>=1.13",
"pytest>=4.1",
"pyro-api>=0.1.1",
"scikit-learn",
Expand Down

0 comments on commit d55d209

Please sign in to comment.