Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to the primitives module #1940

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
20 changes: 8 additions & 12 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,22 @@
import numpyro.distributions as dist


def _non_centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why Array is used instead of ArrayLike?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because ArrayLike can be a scalar (see https://jax.readthedocs.io/en/latest/jax.typing.html) but we want them to be true arrays

) -> Array:
def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)
return jnp.asarray(phi @ (spd * beta))


def _centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta
return jnp.asarray(phi @ beta)


def linear_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
phi: Array, spd: Array, m: int, non_centered: bool = True
) -> Array:
"""
Linear approximation formula of the Hilbert space Gaussian process.
Expand All @@ -52,10 +48,10 @@ def linear_approximation(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike phi: laplacian eigenfunctions
:param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
:param Array phi: laplacian eigenfunctions
:param Array spd: square root of the diagonal of the spectral density evaluated at square
root of the first `m` eigenvalues.
:param int | list[int] m: number of eigenfunctions in the approximation
:param int m: number of eigenfunctions in the approximation
:param bool non_centered: whether to use a non-centered parameterization
:return: The low-rank approximation linear model
:rtype: Array
Expand Down
19 changes: 10 additions & 9 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Union

import numpy as np
import numpy.typing as npt

import jax
from jax import device_get
Expand All @@ -25,7 +26,7 @@
]


def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _compute_chain_variance_stats(x: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why NDArray is used instead of ndarray?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MyPy started complaining while I was adding hints and decided to follow the typing guidelines from numpy https://numpy.org/doc/stable/reference/typing.html :) Then all errors were gone

# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
Expand All @@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray
return var_within, var_estimator


def gelman_rubin(x: np.ndarray) -> np.ndarray:
def gelman_rubin(x: npt.NDArray) -> npt.NDArray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand All @@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray:
return rhat


def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
def split_gelman_rubin(x: npt.NDArray) -> npt.NDArray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int:
target += 1


def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocorrelation(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray:
"""
Computes the autocorrelation of samples at dimension ``axis``.

Expand Down Expand Up @@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarr
autocorr = autocorr / np.arange(N, 0.0, -1)

with np.errstate(invalid="ignore", divide="ignore"):
autocorr = autocorr / autocorr[..., :1]
autocorr = (autocorr / autocorr[..., :1]).astype(np.float64)
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocovariance(x: npt.NDArray, axis: int = 0, bias: bool = True) -> npt.NDArray:
"""
Computes the autocovariance of samples at dimension ``axis``.

Expand All @@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarra
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
def effective_sample_size(x: npt.NDArray, bias: bool = True) -> npt.NDArray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
Expand Down Expand Up @@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
return n_eff


def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
def hpdi(x: npt.NDArray, prob: float = 0.90, axis: int = 0) -> npt.NDArray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
Expand Down Expand Up @@ -285,7 +286,7 @@ def summary(


def print_summary(
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, npt.NDArray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
45 changes: 39 additions & 6 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from collections import OrderedDict
from contextlib import contextmanager
import functools
import inspect
from typing import Protocol, runtime_checkable
import warnings

import numpy as np

import jax
from jax import lax, tree_util
from jax import Array, lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
from numpyro.distributions.util import (
Expand Down Expand Up @@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None:
raise RuntimeError("Cannot validate arguments inside jitted code.")

@property
def batch_shape(self):
def batch_shape(self) -> tuple[int, ...]:
"""
Returns the shape over which the distribution parameters are batched.

Expand All @@ -280,7 +281,7 @@ def batch_shape(self):
return self._batch_shape

@property
def event_shape(self):
def event_shape(self) -> tuple[int, ...]:
"""
Returns the shape of a single sample from the distribution without
batching.
Expand All @@ -291,15 +292,15 @@ def event_shape(self):
return self._event_shape

@property
def event_dim(self):
def event_dim(self) -> int:
"""
:return: Number of dimensions of individual events.
:rtype: int
"""
return len(self.event_shape)

@property
def has_rsample(self):
def has_rsample(self) -> bool:
return set(self.reparametrized_params) == set(self.arg_constraints)

def rsample(self, key, sample_shape=()):
Expand Down Expand Up @@ -563,6 +564,38 @@ def is_discrete(self):
return self.support.is_discrete


@runtime_checkable
class DistributionLike(Protocol):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why Protocol is needed?

Copy link
Contributor Author

@juanitorduz juanitorduz Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This protocol is needed to provide static type checking for functions that accept
distribution objects. From what I see, NumPyro supports multiple distribution implementations:

  • numpyro.distributions.Distribution
  • funsor.Funsor
  • tensorflow_probability.distributions.Distribution

Rather than checking for each specific type, this protocol defines the common interface
that all distribution objects must implement. This allows static type checkers to verify
that any distribution object passed to a function has the required methods and properties,
regardless of which implementation it uses.

If we remove it then we will get many MyPy errors 😅 .

I have seen this approach in other projects so this is why I suggest it :)

Copy link
Contributor Author

@juanitorduz juanitorduz Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(And tensorflow probability and funsor are optional dependencies so we can not use Union to create a common type 🙃)

"""A protocol for typing distributions.

Used to type object of type numpyro.distributions.Distribution, funsor.Funsor
or tensorflow_probability.distributions.Distribution.
"""

@property
def batch_shape(self) -> tuple[int, ...]: ...

@property
def event_shape(self) -> tuple[int, ...]: ...

@property
def event_dim(self) -> int: ...

def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ...

def log_prob(self, value: ArrayLike) -> ArrayLike: ...

@property
def mean(self) -> ArrayLike: ...

@property
def variance(self) -> ArrayLike: ...

def cdf(self, value: ArrayLike) -> ArrayLike: ...

def icdf(self, q: ArrayLike) -> ArrayLike: ...


class ExpandedDistribution(Distribution):
arg_constraints = {}
pytree_data_fields = ("base_dist",)
Expand Down
Loading
Loading