Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to the primitives module #1940

Merged
merged 19 commits into from
Dec 23, 2024
Merged
22 changes: 10 additions & 12 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from __future__ import annotations

from typing import cast
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike
Expand All @@ -21,26 +23,22 @@
import numpyro.distributions as dist


def _non_centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
) -> Array:
def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)
return cast(Array, phi @ (spd * beta))


def _centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta
return cast(Array, phi @ beta)


def linear_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
phi: Array, spd: Array, m: int, non_centered: bool = True
) -> Array:
"""
Linear approximation formula of the Hilbert space Gaussian process.
Expand All @@ -52,10 +50,10 @@ def linear_approximation(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike phi: laplacian eigenfunctions
:param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
:param Array phi: laplacian eigenfunctions
:param Array spd: square root of the diagonal of the spectral density evaluated at square
root of the first `m` eigenvalues.
:param int | list[int] m: number of eigenfunctions in the approximation
:param int m: number of eigenfunctions in the approximation
:param bool non_centered: whether to use a non-centered parameterization
:return: The low-rank approximation linear model
:rtype: Array
Expand Down
19 changes: 10 additions & 9 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Union

import numpy as np
import numpy.typing as npt

import jax
from jax import device_get
Expand All @@ -25,7 +26,7 @@
]


def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _compute_chain_variance_stats(x: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why NDArray is used instead of ndarray?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MyPy started complaining while I was adding hints and decided to follow the typing guidelines from numpy https://numpy.org/doc/stable/reference/typing.html :) Then all errors were gone

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting. thanks! maybe import NDArray instead of npt? It makes the code easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! changed in 1bd90ea

# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
Expand All @@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray
return var_within, var_estimator


def gelman_rubin(x: np.ndarray) -> np.ndarray:
def gelman_rubin(x: npt.NDArray) -> npt.NDArray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray:
return rhat


def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
def split_gelman_rubin(x: npt.NDArray) -> npt.NDArray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int:
target += 1


def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocorrelation(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray:
"""
Computes the autocorrelation of samples at dimension ``axis``.

Expand Down Expand Up @@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarr
autocorr = autocorr / np.arange(N, 0.0, -1)

with np.errstate(invalid="ignore", divide="ignore"):
autocorr = autocorr / autocorr[..., :1]
autocorr = (autocorr / autocorr[..., :1]).astype(np.float64)
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocovariance(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray:
"""
Computes the autocovariance of samples at dimension ``axis``.

Expand All @@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarra
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
def effective_sample_size(x: npt.NDArray, bias: bool = True) -> npt.NDArray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
return n_eff


def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
def hpdi(x: npt.NDArray, prob: float = 0.90, axis: int = 0) -> npt.NDArray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
Expand Down Expand Up @@ -285,7 +286,7 @@ def summary(


def print_summary(
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, npt.NDArray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
Loading
Loading